diff --git a/src/operations.jl b/src/operations.jl index 2978c49..7e72638 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -355,12 +355,11 @@ function solve(o::EnsembleCalibrate; callback) data = o.df - sol_maps_for_cal = Symbol.(names(data)) - - datacal_pairs = [state => data[!,first(values(state.metadata))[2]] for state in states(systems[o.model_ids[1]]) if first(values(state.metadata))[2] in sol_maps_for_cal] - - weights = EasyModelAnalysis.ensemble_weights(sol,datacal_pairs) - DataFrame("Weights" => weights) + data_pairs = [Symbol(name) => data[:,name] for name in names(data)] + data_pairs = filter(x -> x[1] != :timestamp ,data_pairs) + sol_mappings_list = [o.sol_mappings[id] for id in model_ids] + weights = SimulationService.ensemble_weights(sol,data_pairs,sol_mappings_list) + DataFrame(model_ids .=> weights) end # struct Ensemble <: Operation @@ -406,6 +405,7 @@ const route2operation_type = Dict( "ensemble-calibrate" => EnsembleCalibrate ) +# modified from EasyModelAnalysis.jl function sciml_service_l2loss(pvals, (prob, pkeys, data)::Tuple{Vararg{Any, 3}}) p = Pair.(pkeys, pvals) ts = first.(last.(data)) @@ -421,3 +421,12 @@ function sciml_service_l2loss(pvals, (prob, pkeys, data)::Tuple{Vararg{Any, 3}}) end return tot_loss, sol, ts end + +# assumes data is given in the form column_label => data, need sol_mappings to be of form column_label => observable +# modified from EasyModelAnalysis.jl +function ensemble_weights(sol::SciMLBase.EnsembleSolution, data_ensem, sol_mappings_list) + col = first.(data_ensem) + predictions = reduce(vcat, reduce(hcat,[sol[i][Symbol(sol_mappings_list[i][s])] for i in 1:length(sol)]) for s in col) + data = reduce(vcat, [data_ensem[i][2] isa Tuple ? data_ensem[i][2][2] : data_ensem[i][2] for i in 1:length(data_ensem)]) + weights = predictions \ data +end diff --git a/test/runtests.jl b/test/runtests.jl index 962d932..d1120cd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -222,21 +222,21 @@ end end @testset "ensemble-calibrate" begin - amrfiles = [SimulationService.get_json("https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/SEIRD_base_model01_petrinet.json"), - SimulationService.get_json("https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/SEIRHD_base_model01_petrinet.json")] - - amrs = amrfiles - + # more complex ensemble_calibrate + amrs = [SimulationService.get_json("https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/sirhd.json"), + SimulationService.get_json("https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/seiarhds.json")] + obj = ( - model_configs = map(1:2) do i - (id="model_config_id_$i", weight = i / sum(1:2), solution_mappings = (I = "I", R = "R", S = "S")) - end, + model_configs = [ + (id ="sirhd", weight = 1/3, solution_mappings = (Infected = "Infections", Hospitalizations = "hospitalized_population")), + (id = "seirhds", weight = 2/3, solution_mappings = (Infected = "Cases", Hospitalizations = "hospitalized_population"))] + , models = amrs, timespan = (start = 0, var"end" = 40), engine = "sciml", extra = (; num_samples = 40) ) - # do ensemble-simulate + o = OperationRequest() o.route = "ensemble-simulate" o.obj = JSON3.read(JSON3.write(obj)) @@ -244,7 +244,7 @@ end o.timespan = (0,40) en = SimulationService.EnsembleSimulate(o) sim_en_sol = SimulationService.solve(en, callback = nothing) - # create ensemble-calibrate + # ensemble part o = OperationRequest() o.route = "ensemble-calibrate" o.obj = JSON3.read(JSON3.write(obj)) @@ -253,8 +253,7 @@ end o.df = sim_en_sol en_cal = SimulationService.EnsembleCalibrate(o) cal_sol = SimulationService.solve(en_cal,callback = nothing) - @test cal_sol[!,:Weights] ≈ [0.3333333333333333,0.6666666666666666] - + @test cal_sol[!,:sirhd] ≈ [0.3333333333333333] && cal_sol[!,:seirhds] ≈ [0.6666666666666666] end @testset "Real Calibrate Payload" begin