Skip to content

Commit

Permalink
Update tests for AMR changes (#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
joshday authored Sep 7, 2023
1 parent d8cda9c commit c59c148
Show file tree
Hide file tree
Showing 9 changed files with 260 additions and 198 deletions.
274 changes: 158 additions & 116 deletions Manifest.toml

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

[compat]
julia = "1.9"
AMQPClient = "0.5"
CSV = "0.10"
DataFrames = "1.6"
Expand All @@ -47,7 +48,6 @@ Oxygen = "1.1"
SciMLBase = "1.93"
SwaggerMarkdown = "0.2"
YAML = "0.4"
julia = "1.9"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
12 changes: 7 additions & 5 deletions examples/SIRModelConfiguartion.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
{
"id": "55ca595f-940d-458c-9c73-7c315a10b559",
"name": "Default config",
"description": "Default config",
"timestamp": "2023-07-13T20:18:12",
"model_id": "0984a5a9-6438-4041-aa11-2f9ea8fc9d4a",
"header": {
"name": "Default config",
"description": "Default config",
"timestamp": "2023-07-13T20:18:12",
"model_id": "0984a5a9-6438-4041-aa11-2f9ea8fc9d4a"
},
"configuration": {
"id": "0984a5a9-6438-4041-aa11-2f9ea8fc9d4a",
"name": "SIRs",
Expand Down Expand Up @@ -639,4 +641,4 @@
"calibrated": false,
"calibration": null,
"calibration_score": null
}
}
12 changes: 7 additions & 5 deletions examples/calibrate_example1/BIOMD0000000955_askenet.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
{
"name": "Giordano2020 - SIDARTHE model of COVID-19 spread in Italy",
"schema": "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/petrinet_v0.5/petrinet/petrinet_schema.json",
"schema_name": "petrinet",
"description": "Giordano2020 - SIDARTHE model of COVID-19 spread in Italy",
"model_version": "0.1",
"header": {
"name": "Giordano2020 - SIDARTHE model of COVID-19 spread in Italy",
"schema": "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/petrinet_v0.5/petrinet/petrinet_schema.json",
"schema_name": "petrinet",
"description": "Giordano2020 - SIDARTHE model of COVID-19 spread in Italy",
"model_version": "0.1"
},
"properties": {},
"model": {
"states": [
Expand Down
12 changes: 7 additions & 5 deletions examples/calibrate_example2/SIRModelConfiguartion.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
{
"id": "55ca595f-940d-458c-9c73-7c315a10b559",
"name": "Default config",
"description": "Default config",
"timestamp": "2023-07-13T20:18:12",
"model_id": "0984a5a9-6438-4041-aa11-2f9ea8fc9d4a",
"header": {
"name": "Default config",
"description": "Default config",
"timestamp": "2023-07-13T20:18:12",
"model_id": "0984a5a9-6438-4041-aa11-2f9ea8fc9d4a"
},
"configuration": {
"id": "0984a5a9-6438-4041-aa11-2f9ea8fc9d4a",
"name": "SIRs",
Expand Down Expand Up @@ -639,4 +641,4 @@
"calibrated": false,
"calibration": null,
"calibration_score": null
}
}
14 changes: 8 additions & 6 deletions examples/sir_calibrate/sir.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
{
"id": "5bcf7464-5dce-4af5-b8b3-747328f1525e",
"timestamp": "2023-07-10 18:41:39",
"name": "SIRs",
"description": "SIR model",
"username": null,
"id": "5bcf7464-5dce-4af5-b8b3-747328f1525e",
"header": {
"timestamp": "2023-07-10 18:41:39",
"name": "SIRs",
"description": "SIR model",
"username": null
},
"model": {
"states": [
{
Expand Down Expand Up @@ -636,4 +638,4 @@
]
},
"schema": "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/petrinet_v0.5/petrinet/petrinet_schema.json"
}
}
24 changes: 20 additions & 4 deletions src/SimulationService.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import HTTP
import InteractiveUtils: subtypes
import JobSchedulers
import JSON3
import JSONSchema
import JSONSchema
import LinearAlgebra: norm
import MathML
import ModelingToolkit: @parameters, substitute, Differential, Num, @variables, ODESystem, ODEProblem, ODESolution, structural_simplify, states, observed
Expand All @@ -40,6 +40,7 @@ const simulation_schema = Ref{JSON3.Object}()
const petrinet_schema = Ref{JSON3.Object}()
const petrinet_JSONSchema_object = Ref{JSONSchema.Schema}()
const server_url = Ref{String}()
const mock_tds = Ref{Dict{String, Dict{String, JSON3.Object}}}() # e.g. "model" => "model_id" => model

#-----# Environmental Variables:
# Server configuration
Expand All @@ -64,6 +65,7 @@ function __init__()
simulation_schema[] = get_json("https://raw.githubusercontent.com/DARPA-ASKEM/simulation-api-spec/main/schemas/simulation.json")
petrinet_schema[] = get_json("https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/petrinet_schema.json")
petrinet_JSONSchema_object[] = JSONSchema.Schema(petrinet_schema[])

HOST[] = get(ENV, "SIMSERVICE_HOST", "0.0.0.0")
PORT[] = parse(Int, get(ENV, "SIMSERVICE_PORT", "8080"))
ENABLE_TDS[] = get(ENV, "SIMSERVICE_ENABLE_TDS", "true") == "true" #
Expand All @@ -80,8 +82,9 @@ function __init__()
(; MECHANISM = "AMQPLAIN", LOGIN=RABBITMQ_LOGIN, PASSWORD=RABBITMQ_PASSWORD)
)
conn = AMQPClient.connection(; virtualhost="/", host="localhost", port=RABBITMQ_PORT, auth_params)
@info typeof(AMQPClient.channel(conn, AMQPClient.UNUSED_CHANNEL, true))

rabbitmq_channel[] = AMQPClient.channel(conn, AMQPClient.UNUSED_CHANNEL, true)
AMQPClient.queue_declare(rabbitmq_channel[], "sciml-queue")
end

v = Pkg.Types.read_project("Project.toml").version
Expand Down Expand Up @@ -131,6 +134,19 @@ get_json(url::String) = JSON3.read(HTTP.get(url, json_header).body)

timestamp() = Dates.format(now(), "yyyy-mm-ddTHH:MM:SS")

# Run some code with a running server
function with_server(f::Function; wait=1)
try
start!()
sleep(wait)
url = SimulationService.server_url[]
f(url)
catch ex
rethrow(ex)
finally
stop!()
end
end

#-----------------------------------------------------------------------------# job endpoints
get_job(id::String) = JobSchedulers.job_query(jobhash(id))
Expand Down Expand Up @@ -223,15 +239,15 @@ function OperationRequest(req::HTTP.Request, route::String)
# Checks if the JSON model is valid against the petrinet schema
# If not valid, produces a warning saying why
if !isnothing(o.model)
valid_against_schema = JSONSchema.validate(petrinet_JSONSchema_object[],o.model)
valid_against_schema = JSONSchema.validate(petrinet_JSONSchema_object[],o.model)
if !isnothing(valid_against_schema)
@warn "Object not valid against schema: $(valid_against_schema)"
end
end

if !isnothing(o.models)
for model in o.models
valid_against_schema = JSONSchema.validate(petrinet_JSONSchema_object[],model)
valid_against_schema = JSONSchema.validate(petrinet_JSONSchema_object[],model)
if !isnothing(valid_against_schema)
@warn "Object not valid against schema: $(valid_against_schema)"
end
Expand Down
18 changes: 10 additions & 8 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ function amr_get(amr::JSON3.Object, ::Type{ODESystem})
push!(eqs, ofunc ~ expr)
end

sys = structural_simplify(ODESystem(eqs, t, allfuncs, paramvars; defaults = [statefuncs .=> initial_vals; sym_defs], name=Symbol(amr.name)))
defaults = [statefuncs .=> initial_vals; sym_defs]
name = Symbol(amr.header.name)
sys = structural_simplify(ODESystem(eqs, t, allfuncs, paramvars; defaults, name))
@info "amr_get(amr, ODESystem) --> $sys"

sys
Expand Down Expand Up @@ -154,10 +156,10 @@ function Simulate(o::OperationRequest)
Simulate(sys, o.timespan)
end

function solve(op::Simulate; kw...)
function solve(op::Simulate; callback)
# joshday: What does providing `u0 = []` do? Don't we know what u0 is from AMR?
prob = ODEProblem(op.sys, [], op.timespan)
sol = solve(prob; progress = true, progress_steps = 1, saveat=1, kw...)
sol = solve(prob; progress = true, progress_steps = 1, saveat=1, callback)
@info "Timesteps returned are: $(sol.t)"
dataframe_with_observables(sol)
end
Expand Down Expand Up @@ -207,7 +209,7 @@ function solve(o::Calibrate; callback)

probs = [EasyModelAnalysis.remake(prob, p = Pair.(first.(p_posterior), getindex.(pvalues,i))) for i in 1:length(p_posterior[1][2])]
enprob = EasyModelAnalysis.EnsembleProblem(probs)
ensol = solve(enprob, saveat = 1)
ensol = solve(enprob; saveat = 1, callback)
outs = map(1:length(probs)) do i
mats = stack(ensol[i][statenames])'
headers = string.("ensemble",i,"_", statenames)
Expand All @@ -231,7 +233,7 @@ function solve(o::Calibrate; callback)
end

newprob = EasyModelAnalysis.DifferentialEquations.remake(prob, p=fit)
sol = EasyModelAnalysis.DifferentialEquations.solve(newprob; saveat = 1)
sol = EasyModelAnalysis.DifferentialEquations.solve(newprob; saveat = 1, callback)
dfsim = DataFrame(hcat(sol.t,stack(sol[statenames])'), :auto)
rename!(dfsim, ["timestamp";string.(statenames)])

Expand Down Expand Up @@ -273,7 +275,7 @@ function solve(o::Ensemble{Simulate}; callback)
systems = [sim.sys for sim in o.operations]
probs = ODEProblem.(systems, Ref([]), Ref(o.operations[1].timespan))
enprob = EMA.EnsembleProblem(probs)
sol = solve(enprob; saveat = 1);
sol = solve(enprob; saveat = 1, callback);
weights = [0.2, 0.5, 0.3]
data = [x => vec(sum(stack(o.weights .* sol[:,x]), dims = 2)) for x in error("What goes here?")]
end
Expand All @@ -293,7 +295,7 @@ function solve(o::Ensemble{Calibrate}; callback)

# forecast_probs = [EMA.remake(enprobs.prob[i]; tspan = (t_train[1],t_forecast[end])) for i in 1:length(enprobs.prob)]
# fit_enprob = EMA.EnsembleProblem(forecast_probs)
# sol = solve(fit_enprob; saveat = o.t_forecast);
# sol = solve(fit_enprob; saveat = o.t_forecast, callback);

# soldata = DataFrame([sol.t; Matrix(sol[names])'])

Expand Down Expand Up @@ -328,7 +330,7 @@ end

# forecast_probs = [EMA.remake(enprobs.prob[i]; tspan = (t_train[1],t_forecast[end])) for i in 1:length(enprobs.prob)]
# fit_enprob = EMA.EnsembleProblem(forecast_probs)
# sol = solve(fit_enprob; saveat = o.t_forecast);
# sol = solve(fit_enprob; saveat = o.t_forecast, callback);

# soldata = DataFrame([sol.t; Matrix(sol[names])'])

Expand Down
90 changes: 42 additions & 48 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ end
obj = SimulationService.get_json(json_url)
sys = SimulationService.amr_get(obj, ODESystem)
op = Simulate(sys, (0.0, 99.0))
df = solve(op)
df = solve(op; callback = nothing)
@test df isa DataFrame
@test extrema(df.timestamp) == (0.0, 99.0)
end
Expand All @@ -141,9 +141,9 @@ end
ode_method = nothing
o = SimulationService.Calibrate(sys, (0.0, 89.0), priors, data, num_chains, num_iterations, calibrate_method, ode_method)

dfsim, dfparam = SimulationService.solve(o; callback = nothing)
dfsim, dfparam = solve(o; callback = nothing)

statenames = [states(o.sys);getproperty.(observed(o.sys), :lhs)]
statenames = [states(o.sys); getproperty.(observed(o.sys), :lhs)]
@test names(dfsim) == vcat("timestamp",reduce(vcat,[string.("ensemble",i,"_", statenames) for i in 1:size(dfsim,2)÷length(statenames)]))
@test names(dfparam) == string.(parameters(sys))

Expand Down Expand Up @@ -214,58 +214,52 @@ end

#-----------------------------------------------------------------------------# test routes
@testset "Server Routes" begin
start!()

url = SimulationService.server_url[]

sleep(3) # Give server a chance to start

@testset "/" begin
res = HTTP.get(url)
@test res.status == 200
@test JSON3.read(res.body).status == "ok"
end
SimulationService.with_server() do url
@testset "/" begin
res = HTTP.get(url)
@test res.status == 200
@test JSON3.read(res.body).status == "ok"
end

@testset "/docs" begin
res = HTTP.get("$url/docs")
@test res.status == 200
end
@testset "/docs" begin
res = HTTP.get("$url/docs")
@test res.status == 200
end

# Check the status of a job until it finishes
function test_until_done(id::String, every=2)
t = now()
while true
st = get_json("$url/status/$id").status
@info "status from job $(repr(id)) - ($(round(now() - t, Dates.Second))): $st"
st in ["queued", "running", "complete"] && @test true
st in ["failed", "error"] && (@test false; break)
st == "complete" && break
sleep(every)
# Check the status of a job until it finishes
function test_until_done(id::String, every=2)
t = now()
while true
st = get_json("$url/status/$id").status
@info "status from job $(repr(id)) - ($(round(now() - t, Dates.Second))): $st"
st in ["queued", "running", "complete"] && @test true
st in ["failed", "error"] && (@test false; break)
st == "complete" && break
sleep(every)
end
end
end

@testset "/simulate" begin
for body in simulate_payloads
res = HTTP.post("$url/simulate", ["Content-Type" => "application/json"]; body)
@test res.status == 201
id = JSON3.read(res.body).simulation_id
test_until_done(id)
@test SimulationService.last_operation[].result isa DataFrame
@testset "/simulate" begin
for body in simulate_payloads
res = HTTP.post("$url/simulate", ["Content-Type" => "application/json"]; body)
@test res.status == 201
id = JSON3.read(res.body).simulation_id
test_until_done(id)
@test SimulationService.last_operation[].result isa DataFrame
end
end
end

@testset "/calibrate" begin
for body in calibrate_payloads
res = HTTP.post("$url/calibrate", ["Content-Type" => "application/json"]; body)
@test res.status == 201
id = JSON3.read(res.body).simulation_id
test_until_done(id, 5)
@testset "/calibrate" begin
for body in calibrate_payloads
res = HTTP.post("$url/calibrate", ["Content-Type" => "application/json"]; body)
@test res.status == 201
id = JSON3.read(res.body).simulation_id
test_until_done(id, 5)
end
end
end

@testset "/ensemble" begin
@test true # TODO
@testset "/ensemble" begin
@test true # TODO
end
end

stop!()
end

0 comments on commit c59c148

Please sign in to comment.