diff --git a/src/operations.jl b/src/operations.jl index b0e51ff..c6ba2ce 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -110,13 +110,14 @@ function amr_get(df::DataFrame, sys::ODESystem, ::Val{:data}) end #--------------------------------------------------------------------# IntermediateResults callback -# Publish intermediate results to RabbitMQ with at least `every` seconds inbetween callbacks +# Publish intermediate results to RabbitMQ with at least `every` seconds in between callbacks mutable struct IntermediateResults last_callback::Dates.DateTime # Track the last time the callback was called every::Dates.TimePeriod # Callback frequency e.g. `Dates.Second(5)` id::String + iter::Int # Track how many iterations of the calibration have happened function IntermediateResults(id::String; every=Dates.Second(0)) - new(typemin(Dates.DateTime), every, id) + new(typemin(Dates.DateTime), every, id, 0) end end @@ -134,6 +135,17 @@ function (o::IntermediateResults)(integrator) EasyModelAnalysis.DifferentialEquations.u_modified!(integrator, false) end +# Intermediate results functor for calibrate +function (o::IntermediateResults)(p,lossval,ode_sol) + if o.last_callback + o.every ≤ Dates.now() + param_dict = Dict(parameters(ode_sol.prob.f.sys) .=> ode_sol.prob.p) + state_dict = Dict([state => ode_sol[state] for state in states(ode_sol.prob.f.sys)]) + o.iter = o.iter + 1 + publish_to_rabbitmq(; iter = o.iter, loss = lossval, sol_data = state_dict, params = param_dict, id=o.id) + end + + return false +end #----------------------------------------------------------------------# dataframe_with_observables function dataframe_with_observables(sol::ODESolution) sys = sol.prob.f.sys @@ -158,8 +170,7 @@ function Simulate(o::OperationRequest) end function get_callback(o::OperationRequest, ::Type{Simulate}) - DiscreteCallback((args...) -> true, IntermediateResults(o.id,every = Dates.Second(0)), - save_positions = (false,false)) + DiscreteCallback((args...) -> true, IntermediateResults(o.id,every = Dates.Second(0))) end # callback for Simulate requests @@ -184,11 +195,7 @@ end # callback for Calibrate requests function get_callback(o::OperationRequest, ::Type{Calibrate}) - function (p,lossval,ode_sol) - param_dict = Dict(parameters(ode_sol.prob.f.sys) .=> ode_sol.prob.p) - state_dict = Dict([state => ode_sol[state] for state in states(ode_sol.prob.f.sys)]) - publish_to_rabbitmq(; loss = lossval, sol_data = state_dict, params = param_dict, id=o.id) - end + IntermediateResults(o.id,every = Dates.Second(0)) end function Calibrate(o::OperationRequest)