From a131977a5d2f80ed720f0105cf7544ba530ba299 Mon Sep 17 00:00:00 2001
From: Five Grant <5@fivegrant.com>
Date: Mon, 18 Mar 2024 16:41:05 -0500
Subject: [PATCH 1/3] Reinclude old endpoint
---
pyproject.toml | 2 +-
service/api.py | 13 +-
service/execute.py | 1 +
service/models/operations/__init__.py | 1 +
.../models/operations/ensemble_calibrate.py | 80 ++++
.../398d2a33-1cbe-44cf-9b66-a6976b947809.json | 405 ----------------
.../ensemble-calibrate/input/ensemble.csv | 2 +-
.../fab58d86-cf1e-4990-809a-58c029545a3a.json | 282 -----------
.../ensemble-calibrate/input/left.json | 430 +++++++++++++++++
.../ensemble-calibrate/input/request.json | 19 +-
.../ensemble-calibrate/input/right.json | 444 ++++++++++++++++++
tests/integration/test_ensemble_calibrate.py | 73 +++
tests/test_conversions.py | 32 ++
13 files changed, 1077 insertions(+), 707 deletions(-)
create mode 100644 service/models/operations/ensemble_calibrate.py
delete mode 100644 tests/examples/ensemble-calibrate/input/398d2a33-1cbe-44cf-9b66-a6976b947809.json
delete mode 100644 tests/examples/ensemble-calibrate/input/fab58d86-cf1e-4990-809a-58c029545a3a.json
create mode 100644 tests/examples/ensemble-calibrate/input/left.json
create mode 100644 tests/examples/ensemble-calibrate/input/right.json
create mode 100644 tests/integration/test_ensemble_calibrate.py
diff --git a/pyproject.toml b/pyproject.toml
index c62ab0e..f7990f5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@ httpx = "^0.24.1"
[tool.poe.tasks]
-install-pyciemss = "pip install --no-cache-dir pyro-ppl==1.8.6 git+https://github.com/ciemss/pyciemss.git@d6838e72bdc145b2f87ab9e33e220eb84fd87e87 --use-pep517"
+install-pyciemss = "pip install --no-cache-dir git+https://github.com/ciemss/pyciemss.git@1fc62b0d4b0870ca992514ad7a9b7a09a175ce44 --use-pep517"
[tool.pytest.ini_options]
diff --git a/service/api.py b/service/api.py
index 8faedb8..8d0d62d 100644
--- a/service/api.py
+++ b/service/api.py
@@ -2,7 +2,7 @@
import logging
import os
-from fastapi import FastAPI, Depends, HTTPException
+from fastapi import FastAPI, Depends
from fastapi.middleware.cors import CORSMiddleware
from service.models import (
@@ -11,6 +11,7 @@
Calibrate,
Simulate,
EnsembleSimulate,
+ EnsembleCalibrate,
Optimize,
StatusSimulationIdGetResponse,
)
@@ -21,6 +22,7 @@
"simulate": Simulate,
"calibrate": Calibrate,
"ensemble-simulate": EnsembleSimulate,
+ "ensemble-calibrate": EnsembleCalibrate,
"optimize": Optimize,
}
@@ -111,12 +113,3 @@ def operate(
operate = make_operate(operation_name)
registrar(operate)
-
-
-@app.get("/ensemble-calibrate", response_model=StatusSimulationIdGetResponse)
-def ensemble_calibrate_not_yet_implemented():
- """
- DO NOT USE. Placeholder for `ensemble-calibrate` endpoint.
- This will be reimplemented in the future.
- """
- raise HTTPException(status=501, detail="Not yet reimplemented")
diff --git a/service/execute.py b/service/execute.py
index 5847689..7657d40 100644
--- a/service/execute.py
+++ b/service/execute.py
@@ -11,6 +11,7 @@
sample,
calibrate,
ensemble_sample,
+ ensemble_calibrate,
optimize,
)
diff --git a/service/models/operations/__init__.py b/service/models/operations/__init__.py
index d7c670c..0a206fc 100644
--- a/service/models/operations/__init__.py
+++ b/service/models/operations/__init__.py
@@ -1,4 +1,5 @@
from models.operations.simulate import Simulate
from models.operations.calibrate import Calibrate
from models.operations.ensemble_simulate import EnsembleSimulate
+from models.operations.ensemble_calibrate import EnsembleCalibrate
from models.operations.optimize import Optimize
diff --git a/service/models/operations/ensemble_calibrate.py b/service/models/operations/ensemble_calibrate.py
new file mode 100644
index 0000000..3ecf549
--- /dev/null
+++ b/service/models/operations/ensemble_calibrate.py
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+from typing import ClassVar, List, Dict, Any
+
+import socket
+import logging
+from pydantic import BaseModel, Field, Extra
+import torch # TODO: Do not use Torch in PyCIEMSS Library interface
+
+from pika.exceptions import AMQPConnectionError
+
+from models.base import Dataset, OperationRequest, Timespan, ModelConfig
+from models.converters import convert_to_solution_mapping
+from utils.rabbitmq import gen_rabbitmq_hook
+from utils.tds import fetch_dataset, fetch_model
+
+
+class EnsembleCalibrateExtra(BaseModel):
+ noise_model: str = "normal"
+ noise_model_kwargs: Dict[str, Any] = {"scale": 0.1}
+ solver_method: str = "dopri5"
+ solver_options: Dict[str, Any] = {}
+ num_iterations: int = 1000
+ lr: float = 0.03
+ verbose: bool = False
+ num_particles: int = 1
+ deterministic_learnable_parameters: List[str] = []
+
+
+class EnsembleCalibrate(OperationRequest):
+ pyciemss_lib_function: ClassVar[str] = "ensemble_calibrate"
+ model_configs: List[ModelConfig] = Field(
+ [],
+ example=[],
+ )
+ timespan: Timespan
+ dataset: Dataset = None
+
+ step_size: float = 1.0
+
+ extra: EnsembleCalibrateExtra = Field(
+ None,
+ description="optional extra system specific arguments for advanced use cases",
+ )
+
+ def gen_pyciemss_args(self, job_id):
+ weights = torch.tensor([config.weight for config in self.model_configs])
+ solution_mappings = [
+ convert_to_solution_mapping(config) for config in self.model_configs
+ ]
+ amr_paths = [fetch_model(config.id, job_id) for config in self.model_configs]
+ dataset_path = fetch_dataset(self.dataset.dict(), job_id)
+
+ try:
+ hook = gen_rabbitmq_hook(job_id)
+ except (socket.gaierror, AMQPConnectionError):
+ logging.warning(
+ "%s: Failed to connect to RabbitMQ. Unable to log progress", job_id
+ )
+
+ def hook(progress, _loss):
+ progress = progress / 10 # TODO: Fix magnitude of progress upstream
+ if progress == int(progress):
+ logging.info(f"Calibration is {progress}% complete")
+ return None
+
+ return {
+ "model_paths_or_jsons": amr_paths,
+ "solution_mappings": solution_mappings,
+ "data_path": dataset_path,
+ "start_time": self.timespan.start,
+ # "end_time": self.timespan.end,
+ "dirichlet_alpha": weights,
+ "progress_hook": hook,
+ # "visual_options": True,
+ **self.extra.dict(),
+ }
+
+ class Config:
+ extra = Extra.forbid
diff --git a/tests/examples/ensemble-calibrate/input/398d2a33-1cbe-44cf-9b66-a6976b947809.json b/tests/examples/ensemble-calibrate/input/398d2a33-1cbe-44cf-9b66-a6976b947809.json
deleted file mode 100644
index 7baad98..0000000
--- a/tests/examples/ensemble-calibrate/input/398d2a33-1cbe-44cf-9b66-a6976b947809.json
+++ /dev/null
@@ -1,405 +0,0 @@
-{
- "id": "398d2a33-1cbe-44cf-9b66-a6976b947809",
- "name": "Left",
- "description": "The left one",
- "timestamp": "2023-08-03T18:53:20",
- "model_id": "b4a38a62-16d9-4317-8cfc-510b37152776",
- "configuration": {
- "header":{
- "name": "Model",
- "schema": "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/petrinet_v0.5/petrinet/petrinet_schema.json",
- "description": "Model",
- "model_version": "0.1"
- },
- "model": {
- "states": [
- {
- "id": "susceptible_population",
- "name": "susceptible_population",
- "grounding": {
- "identifiers": {
- "ido": "0000514"
- },
- "modifiers": {}
- }
- },
- {
- "id": "symptomatic_population",
- "name": "symptomatic_population",
- "grounding": {
- "identifiers": {
- "ido": "0000573"
- },
- "modifiers": {}
- }
- },
- {
- "id": "asymptomatic_population",
- "name": "asymptomatic_population",
- "grounding": {
- "identifiers": {
- "ido": "0000569"
- },
- "modifiers": {}
- }
- },
- {
- "id": "exposed_population",
- "name": "exposed_population",
- "grounding": {
- "identifiers": {
- "ido": "0000594"
- },
- "modifiers": {}
- }
- },
- {
- "id": "recovered_population",
- "name": "recovered_population",
- "grounding": {
- "identifiers": {
- "ido": "0000592"
- },
- "modifiers": {}
- }
- },
- {
- "id": "hospitalized_population",
- "name": "hospitalized_population",
- "grounding": {
- "identifiers": {
- "ncit": "C25179"
- },
- "modifiers": {}
- }
- },
- {
- "id": "deceased_population",
- "name": "deceased_population",
- "grounding": {
- "identifiers": {
- "ncit": "C168970"
- },
- "modifiers": {}
- }
- }
- ],
- "transitions": [
- {
- "id": "t1",
- "input": [
- "symptomatic_population",
- "asymptomatic_population",
- "susceptible_population"
- ],
- "output": [
- "symptomatic_population",
- "asymptomatic_population",
- "exposed_population"
- ],
- "properties": {
- "name": "t1"
- }
- },
- {
- "id": "t2",
- "input": [
- "exposed_population"
- ],
- "output": [
- "symptomatic_population"
- ],
- "properties": {
- "name": "t2"
- }
- },
- {
- "id": "t3",
- "input": [
- "exposed_population"
- ],
- "output": [
- "asymptomatic_population"
- ],
- "properties": {
- "name": "t3"
- }
- },
- {
- "id": "t4",
- "input": [
- "symptomatic_population"
- ],
- "output": [
- "recovered_population"
- ],
- "properties": {
- "name": "t4"
- }
- },
- {
- "id": "t5",
- "input": [
- "symptomatic_population"
- ],
- "output": [
- "hospitalized_population"
- ],
- "properties": {
- "name": "t5"
- }
- },
- {
- "id": "t6",
- "input": [
- "symptomatic_population"
- ],
- "output": [
- "deceased_population"
- ],
- "properties": {
- "name": "t6"
- }
- },
- {
- "id": "t7",
- "input": [
- "asymptomatic_population"
- ],
- "output": [
- "recovered_population"
- ],
- "properties": {
- "name": "t7"
- }
- },
- {
- "id": "t8",
- "input": [
- "hospitalized_population"
- ],
- "output": [
- "recovered_population"
- ],
- "properties": {
- "name": "t8"
- }
- },
- {
- "id": "t9",
- "input": [
- "hospitalized_population"
- ],
- "output": [
- "deceased_population"
- ],
- "properties": {
- "name": "t9"
- }
- },
- {
- "id": "t10",
- "input": [
- "recovered_population"
- ],
- "output": [
- "susceptible_population"
- ],
- "properties": {
- "name": "t10"
- }
- }
- ]
- },
- "semantics": {
- "ode": {
- "rates": [
- {
- "target": "t1",
- "expression": "beta*susceptible_population*(asymptomatic_population + delta*symptomatic_population)/total_population",
- "expression_mathml": "betasusceptible_populationasymptomatic_populationdeltasymptomatic_populationtotal_population"
- },
- {
- "target": "t2",
- "expression": "exposed_population*pS/alpha",
- "expression_mathml": "exposed_populationpSalpha"
- },
- {
- "target": "t3",
- "expression": "exposed_population*(1 - pS)/alpha",
- "expression_mathml": "exposed_population1pSalpha"
- },
- {
- "target": "t4",
- "expression": "gamma*symptomatic_population*(-dnh - hosp + 1)",
- "expression_mathml": "gammasymptomatic_populationdnhhosp1"
- },
- {
- "target": "t5",
- "expression": "gamma*hosp*symptomatic_population",
- "expression_mathml": "gammahospsymptomatic_population"
- },
- {
- "target": "t6",
- "expression": "dnh*gamma*symptomatic_population",
- "expression_mathml": "dnhgammasymptomatic_population"
- },
- {
- "target": "t7",
- "expression": "asymptomatic_population*gamma",
- "expression_mathml": "asymptomatic_populationgamma"
- },
- {
- "target": "t8",
- "expression": "hospitalized_population*(1 - dh)/los",
- "expression_mathml": "hospitalized_population1dhlos"
- },
- {
- "target": "t9",
- "expression": "dh*hospitalized_population/los",
- "expression_mathml": "dhhospitalized_populationlos"
- },
- {
- "target": "t10",
- "expression": "recovered_population/tau",
- "expression_mathml": "recovered_populationtau"
- }
- ],
- "initials": [
- {
- "target": "susceptible_population",
- "expression": "99999.0000000000",
- "expression_mathml": "99999.0"
- },
- {
- "target": "symptomatic_population",
- "expression": "1.00000000000000",
- "expression_mathml": "1.0"
- },
- {
- "target": "asymptomatic_population",
- "expression": "0.0",
- "expression_mathml": "0.0"
- },
- {
- "target": "exposed_population",
- "expression": "0.0",
- "expression_mathml": "0.0"
- },
- {
- "target": "recovered_population",
- "expression": "0.0",
- "expression_mathml": "0.0"
- },
- {
- "target": "hospitalized_population",
- "expression": "0.0",
- "expression_mathml": "0.0"
- },
- {
- "target": "deceased_population",
- "expression": "0.0",
- "expression_mathml": "0.0"
- }
- ],
- "parameters": [
- {
- "id": "beta",
- "value": 0.55,
- "distribution": {
- "type": "Uniform1",
- "parameters": {
- "minimum": 0.5,
- "maximum": 0.6
- }
- }
- },
- {
- "id": "delta",
- "value": 1.5,
- "distribution": {
- "type": "Uniform1",
- "parameters": {
- "minimum": 1,
- "maximum": 2
- }
- }
- },
- {
- "id": "total_population",
- "value": 100000
- },
- {
- "id": "alpha",
- "value": 4,
- "distribution": {
- "type": "Uniform1",
- "parameters": {
- "minimum": 3,
- "maximum": 5
- }
- }
- },
- {
- "id": "pS",
- "value": 0.7
- },
- {
- "id": "gamma",
- "value": 0.2
- },
- {
- "id": "hosp",
- "value": 0.1
- },
- {
- "id": "dnh",
- "value": 0.001
- },
- {
- "id": "dh",
- "value": 0.1
- },
- {
- "id": "los",
- "value": 7
- },
- {
- "id": "tau",
- "value": 30
- }
- ],
- "observables": [
- {
- "id": "Cases",
- "name": "Cases",
- "expression": "symptomatic_population + asymptomatic_population",
- "expression_mathml": "asymptomaticpopulationsymptomaticpopuation"
- }
- ],
- "time": {
- "id": "t"
- }
- }
- },
- "metadata": {
- "annotations": {
- "license": null,
- "authors": [],
- "references": [],
- "time_scale": null,
- "time_start": null,
- "time_end": null,
- "locations": [],
- "pathogens": [],
- "diseases": [],
- "hosts": [],
- "model_types": []
- }
- }
- },
- "amr_configuration": null,
- "calibrated": false,
- "calibration": null,
- "calibration_score": null
-}
\ No newline at end of file
diff --git a/tests/examples/ensemble-calibrate/input/ensemble.csv b/tests/examples/ensemble-calibrate/input/ensemble.csv
index 2533420..a57f834 100644
--- a/tests/examples/ensemble-calibrate/input/ensemble.csv
+++ b/tests/examples/ensemble-calibrate/input/ensemble.csv
@@ -1,4 +1,4 @@
-Timestep,Infected,Hospitalizations
+tstep,Infected,Hospitalizations
0.0,1.0,1.9617580129804857e-12
1.0,0.9480640292167664,0.01719011925160885
2.0,1.066840648651123,0.031743988394737244
diff --git a/tests/examples/ensemble-calibrate/input/fab58d86-cf1e-4990-809a-58c029545a3a.json b/tests/examples/ensemble-calibrate/input/fab58d86-cf1e-4990-809a-58c029545a3a.json
deleted file mode 100644
index 57a667c..0000000
--- a/tests/examples/ensemble-calibrate/input/fab58d86-cf1e-4990-809a-58c029545a3a.json
+++ /dev/null
@@ -1,282 +0,0 @@
-{
- "id": "fab58d86-cf1e-4990-809a-58c029545a3a",
- "name": "Left",
- "description": "The left one",
- "timestamp": "2023-08-03T18:56:20",
- "model_id": "b4a38a62-16d9-4317-8cfc-510b37152776",
- "configuration": {
- "header":{
- "name": "Model",
- "schema": "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/petrinet_v0.5/petrinet/petrinet_schema.json",
- "description": "Model",
- "model_version": "0.1"
- },
- "model": {
- "states": [
- {
- "id": "susceptible_population",
- "name": "susceptible_population",
- "grounding": {
- "identifiers": {
- "ido": "0000514"
- },
- "modifiers": {}
- }
- },
- {
- "id": "infectious_population",
- "name": "infectious_population",
- "grounding": {
- "identifiers": {
- "ido": "0000513"
- },
- "modifiers": {}
- }
- },
- {
- "id": "recovered_population",
- "name": "recovered_population",
- "grounding": {
- "identifiers": {
- "ido": "0000592"
- },
- "modifiers": {}
- }
- },
- {
- "id": "hospitalized_population",
- "name": "hospitalized_population",
- "grounding": {
- "identifiers": {
- "ncit": "C25179"
- },
- "modifiers": {}
- }
- },
- {
- "id": "deceased_population",
- "name": "deceased_population",
- "grounding": {
- "identifiers": {
- "ncit": "C168970"
- },
- "modifiers": {}
- }
- }
- ],
- "transitions": [
- {
- "id": "t1",
- "input": [
- "infectious_population",
- "susceptible_population"
- ],
- "output": [
- "infectious_population",
- "infectious_population"
- ],
- "properties": {
- "name": "t1"
- }
- },
- {
- "id": "t2",
- "input": [
- "infectious_population"
- ],
- "output": [
- "recovered_population"
- ],
- "properties": {
- "name": "t2"
- }
- },
- {
- "id": "t3",
- "input": [
- "infectious_population"
- ],
- "output": [
- "hospitalized_population"
- ],
- "properties": {
- "name": "t3"
- }
- },
- {
- "id": "t4",
- "input": [
- "infectious_population"
- ],
- "output": [
- "deceased_population"
- ],
- "properties": {
- "name": "t4"
- }
- },
- {
- "id": "t5",
- "input": [
- "hospitalized_population"
- ],
- "output": [
- "recovered_population"
- ],
- "properties": {
- "name": "t5"
- }
- },
- {
- "id": "t6",
- "input": [
- "hospitalized_population"
- ],
- "output": [
- "deceased_population"
- ],
- "properties": {
- "name": "t6"
- }
- }
- ]
- },
- "semantics": {
- "ode": {
- "rates": [
- {
- "target": "t1",
- "expression": "beta*infectious_population*susceptible_population/total_population",
- "expression_mathml": "betainfectious_populationsusceptible_populationtotal_population"
- },
- {
- "target": "t2",
- "expression": "gamma*infectious_population*(-dnh - hosp + 1)",
- "expression_mathml": "gammainfectious_populationdnhhosp1"
- },
- {
- "target": "t3",
- "expression": "gamma*hosp*infectious_population",
- "expression_mathml": "gammahospinfectious_population"
- },
- {
- "target": "t4",
- "expression": "dnh*gamma*infectious_population",
- "expression_mathml": "dnhgammainfectious_population"
- },
- {
- "target": "t5",
- "expression": "hospitalized_population*(1 - dh)/los",
- "expression_mathml": "hospitalized_population1dhlos"
- },
- {
- "target": "t6",
- "expression": "dh*hospitalized_population/los",
- "expression_mathml": "dhhospitalized_populationlos"
- }
- ],
- "initials": [
- {
- "target": "susceptible_population",
- "expression": "99999.0000000000",
- "expression_mathml": "99999.0"
- },
- {
- "target": "infectious_population",
- "expression": "1.00000000000000",
- "expression_mathml": "1.0"
- },
- {
- "target": "recovered_population",
- "expression": "0.0",
- "expression_mathml": "0.0"
- },
- {
- "target": "hospitalized_population",
- "expression": "0.0",
- "expression_mathml": "0.0"
- },
- {
- "target": "deceased_population",
- "expression": "0.0",
- "expression_mathml": "0.0"
- }
- ],
- "parameters": [
- {
- "id": "beta",
- "value": 0.55,
- "distribution": {
- "type": "Uniform1",
- "parameters": {
- "minimum": 0.5,
- "maximum": 0.6
- }
- }
- },
- {
- "id": "total_population",
- "value": 100000
- },
- {
- "id": "gamma",
- "value": 0.2,
- "distribution": {
- "type": "Uniform1",
- "parameters": {
- "minimum": 0.1,
- "maximum": 0.2
- }
- }
- },
- {
- "id": "hosp",
- "value": 0.1
- },
- {
- "id": "dnh",
- "value": 0.001
- },
- {
- "id": "dh",
- "value": 0.1
- },
- {
- "id": "los",
- "value": 7
- }
- ],
- "observables": [
- {
- "id": "Infections",
- "name": "Infections",
- "expression": "infectious_population",
- "expression_mathml": "infectiouspopulation"
- }
- ],
- "time": {
- "id": "t"
- }
- }
- },
- "metadata": {
- "annotations": {
- "license": null,
- "authors": [],
- "references": [],
- "time_scale": null,
- "time_start": null,
- "time_end": null,
- "locations": [],
- "pathogens": [],
- "diseases": [],
- "hosts": [],
- "model_types": []
- }
- }
- },
- "amr_configuration": null,
- "calibrated": false,
- "calibration": null,
- "calibration_score": null
-}
\ No newline at end of file
diff --git a/tests/examples/ensemble-calibrate/input/left.json b/tests/examples/ensemble-calibrate/input/left.json
new file mode 100644
index 0000000..7a9ac52
--- /dev/null
+++ b/tests/examples/ensemble-calibrate/input/left.json
@@ -0,0 +1,430 @@
+{
+ "configuration": {
+ "id": "seirhd-with-reinfection01-petrinet",
+ "header": {
+ "name": "SEIRHD_with_reinfection01",
+ "schema": "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/petrinet_v0.5/petrinet/petrinet_schema.json",
+ "schema_name": "petrinet",
+ "description": "SEIRHD_with_reinfection01",
+ "model_version": "0.1",
+ "properties": {}
+ },
+ "model": {
+ "states": [
+ {
+ "id": "S",
+ "name": "S",
+ "grounding": {
+ "identifiers": {
+ "ido": "0000514"
+ },
+ "modifiers": {}
+ },
+ "units": {
+ "expression": "person",
+ "expression_mathml": "person"
+ }
+ },
+ {
+ "id": "I",
+ "name": "I",
+ "grounding": {
+ "identifiers": {
+ "ido": "0000511"
+ },
+ "modifiers": {}
+ },
+ "units": {
+ "expression": "person",
+ "expression_mathml": "person"
+ }
+ },
+ {
+ "id": "E",
+ "name": "E",
+ "grounding": {
+ "identifiers": {
+ "apollosv": "0000154"
+ },
+ "modifiers": {}
+ },
+ "units": {
+ "expression": "person",
+ "expression_mathml": "person"
+ }
+ },
+ {
+ "id": "R",
+ "name": "R",
+ "grounding": {
+ "identifiers": {
+ "ido": "0000592"
+ },
+ "modifiers": {}
+ },
+ "units": {
+ "expression": "person",
+ "expression_mathml": "person"
+ }
+ },
+ {
+ "id": "H",
+ "name": "H",
+ "grounding": {
+ "identifiers": {
+ "ido": "0000511"
+ },
+ "modifiers": {
+ "property": "ncit:C25179"
+ }
+ },
+ "units": {
+ "expression": "person",
+ "expression_mathml": "person"
+ }
+ },
+ {
+ "id": "D",
+ "name": "D",
+ "grounding": {
+ "identifiers": {
+ "ncit": "C28554"
+ },
+ "modifiers": {}
+ },
+ "units": {
+ "expression": "person",
+ "expression_mathml": "person"
+ }
+ }
+ ],
+ "transitions": [
+ {
+ "id": "t1",
+ "input": [
+ "I",
+ "S"
+ ],
+ "output": [
+ "I",
+ "E"
+ ],
+ "properties": {
+ "name": "t1"
+ }
+ },
+ {
+ "id": "t2",
+ "input": [
+ "E"
+ ],
+ "output": [
+ "I"
+ ],
+ "properties": {
+ "name": "t2"
+ }
+ },
+ {
+ "id": "t3",
+ "input": [
+ "I"
+ ],
+ "output": [
+ "R"
+ ],
+ "properties": {
+ "name": "t3"
+ }
+ },
+ {
+ "id": "t4",
+ "input": [
+ "I"
+ ],
+ "output": [
+ "H"
+ ],
+ "properties": {
+ "name": "t4"
+ }
+ },
+ {
+ "id": "t5",
+ "input": [
+ "H"
+ ],
+ "output": [
+ "R"
+ ],
+ "properties": {
+ "name": "t5"
+ }
+ },
+ {
+ "id": "t6",
+ "input": [
+ "H"
+ ],
+ "output": [
+ "D"
+ ],
+ "properties": {
+ "name": "t6"
+ }
+ },
+ {
+ "id": "t7",
+ "input": [
+ "R"
+ ],
+ "output": [
+ "S"
+ ],
+ "properties": {
+ "name": "t7"
+ }
+ }
+ ]
+ },
+ "semantics": {
+ "ode": {
+ "rates": [
+ {
+ "target": "t1",
+ "expression": "I*S*beta/total_population",
+ "expression_mathml": "ISbetatotal_population"
+ },
+ {
+ "target": "t2",
+ "expression": "E*delta",
+ "expression_mathml": "Edelta"
+ },
+ {
+ "target": "t3",
+ "expression": "I*gamma*(1 - hosp)",
+ "expression_mathml": "Igamma1hosp"
+ },
+ {
+ "target": "t4",
+ "expression": "I*gamma*hosp",
+ "expression_mathml": "Igammahosp"
+ },
+ {
+ "target": "t5",
+ "expression": "H*(1 - death_hosp)/los",
+ "expression_mathml": "H1death_hosplos"
+ },
+ {
+ "target": "t6",
+ "expression": "H*death_hosp/los",
+ "expression_mathml": "Hdeath_hosplos"
+ },
+ {
+ "target": "t7",
+ "expression": "R*roil",
+ "expression_mathml": "Rroil"
+ }
+ ],
+ "initials": [
+ {
+ "target": "S",
+ "expression": "total_population - I0",
+ "expression_mathml": "total_populationI0"
+ },
+ {
+ "target": "I",
+ "expression": "I0",
+ "expression_mathml": "I0"
+ },
+ {
+ "target": "E",
+ "expression": "40.0000000000000",
+ "expression_mathml": "40.0"
+ },
+ {
+ "target": "R",
+ "expression": "0.0",
+ "expression_mathml": "0.0"
+ },
+ {
+ "target": "H",
+ "expression": "0.0",
+ "expression_mathml": "0.0"
+ },
+ {
+ "target": "D",
+ "expression": "0.0",
+ "expression_mathml": "0.0"
+ }
+ ],
+ "parameters": [
+ {
+ "id": "beta",
+ "value": 0.4,
+ "distribution": {
+ "type": "Uniform1",
+ "parameters": {
+ "minimum": 0.05,
+ "maximum": 0.8
+ }
+ },
+ "units": {
+ "expression": "1/day",
+ "expression_mathml": "day-1"
+ }
+ },
+ {
+ "id": "total_population",
+ "value": 19340000.0,
+ "units": {
+ "expression": "person",
+ "expression_mathml": "person"
+ }
+ },
+ {
+ "id": "I0",
+ "value": 10.0,
+ "distribution": {
+ "type": "Uniform1",
+ "parameters": {
+ "minimum": 1.0,
+ "maximum": 15.0
+ }
+ },
+ "units": {
+ "expression": "person",
+ "expression_mathml": "person"
+ }
+ },
+ {
+ "id": "delta",
+ "value": 0.25,
+ "units": {
+ "expression": "1/day",
+ "expression_mathml": "day-1"
+ }
+ },
+ {
+ "id": "gamma",
+ "value": 0.2,
+ "distribution": {
+ "type": "Uniform1",
+ "parameters": {
+ "minimum": 0.1,
+ "maximum": 0.5
+ }
+ },
+ "units": {
+ "expression": "1/day",
+ "expression_mathml": "day-1"
+ }
+ },
+ {
+ "id": "hosp",
+ "value": 0.1,
+ "distribution": {
+ "type": "Uniform1",
+ "parameters": {
+ "minimum": 0.005,
+ "maximum": 0.2
+ }
+ },
+ "units": {
+ "expression": "1/day",
+ "expression_mathml": "day-1"
+ }
+ },
+ {
+ "id": "death_hosp",
+ "value": 0.07,
+ "distribution": {
+ "type": "Uniform1",
+ "parameters": {
+ "minimum": 0.01,
+ "maximum": 0.1
+ }
+ },
+ "units": {
+ "expression": "1/day",
+ "expression_mathml": "day-1"
+ }
+ },
+ {
+ "id": "los",
+ "value": 5.0,
+ "units": {
+ "expression": "day",
+ "expression_mathml": "day"
+ }
+ },
+ {
+ "id": "roil",
+ "value": 0.0027397260273972603,
+ "units": {
+ "expression": "1/day",
+ "expression_mathml": "day-1"
+ }
+ }
+ ],
+ "observables": [
+ {
+ "id": "infected",
+ "name": "infected",
+ "expression": "I",
+ "expression_mathml": "I"
+ },
+ {
+ "id": "exposed",
+ "name": "exposed",
+ "expression": "E",
+ "expression_mathml": "E"
+ },
+ {
+ "id": "hospitalized",
+ "name": "hospitalized",
+ "expression": "H",
+ "expression_mathml": "H"
+ },
+ {
+ "id": "dead",
+ "name": "dead",
+ "expression": "D",
+ "expression_mathml": "D"
+ }
+ ],
+ "time": {
+ "id": "t",
+ "units": {
+ "expression": "day",
+ "expression_mathml": "day"
+ }
+ }
+ }
+ },
+ "metadata": {
+ "annotations": {
+ "license": null,
+ "authors": [],
+ "references": [],
+ "time_scale": null,
+ "time_start": null,
+ "time_end": null,
+ "locations": [],
+ "pathogens": [],
+ "diseases": [],
+ "hosts": [],
+ "model_types": []
+ }
+ }
+ },
+ "id": "left",
+ "name": "left",
+ "description": "no changes made",
+ "timestamp": "2023-08-03T18:00:40",
+ "model_id": "b4a38a62-16d9-4317-8cfc-510b37152776",
+ "amr_configuration": null,
+ "calibrated": false,
+ "calibration": null,
+ "calibration_score": null
+}
\ No newline at end of file
diff --git a/tests/examples/ensemble-calibrate/input/request.json b/tests/examples/ensemble-calibrate/input/request.json
index 983be20..1c7d1a8 100644
--- a/tests/examples/ensemble-calibrate/input/request.json
+++ b/tests/examples/ensemble-calibrate/input/request.json
@@ -2,27 +2,30 @@
"engine": "ciemss",
"model_configs": [
{
- "id": "398d2a33-1cbe-44cf-9b66-a6976b947809",
+ "id": "right",
"weight": 0.5,
- "solution_mappings": {"Infected": "Cases", "Hospitalizations": "hospitalized_population"}
+ "solution_mappings": {"Infected": "I", "Hospitalizations": "H"}
},
{
- "id": "fab58d86-cf1e-4990-809a-58c029545a3a",
+ "id": "left",
"weight": 0.5,
- "solution_mappings": {"Infected": "Infections", "Hospitalizations": "hospitalized_population"}
+ "solution_mappings": {"Infected": "I", "Hospitalizations": "H"}
}
],
"dataset": {
"id": "cc52e71f-c744-4883-9412-0858d8455754",
- "filename": "ensemble.csv"
+ "filename": "ensemble.csv",
+ "mappings": {
+ "tstep": "Timestamp",
+ "Infected": "Infected",
+ "Hospitalizations": "Hospitalizations"
+ }
},
"timespan": {
"start": 0,
"end": 5
},
"extra": {
- "num_samples": 200,
- "total_population": 1000,
- "num_iterations": 8
+ "num_iterations": 2
}
}
\ No newline at end of file
diff --git a/tests/examples/ensemble-calibrate/input/right.json b/tests/examples/ensemble-calibrate/input/right.json
new file mode 100644
index 0000000..df4a115
--- /dev/null
+++ b/tests/examples/ensemble-calibrate/input/right.json
@@ -0,0 +1,444 @@
+{
+ "configuration": {
+ "id": "seirhd-npi-type1-petrinet",
+ "header": {
+ "name": "SEIRHD model NPI Type 1",
+ "schema": "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/petrinet_v0.6/petrinet/petrinet_schema.json",
+ "schema_name": "petrinet",
+ "description": "SEIRHD model NPI Type 1",
+ "model_version": "0.1"
+ },
+ "properties": {},
+ "model": {
+ "states": [
+ {
+ "id": "S",
+ "name": "S",
+ "grounding": {
+ "identifiers": {
+ "ido": "0000514"
+ },
+ "modifiers": {}
+ },
+ "units": {
+ "expression": "person",
+ "expression_mathml": "person"
+ }
+ },
+ {
+ "id": "I",
+ "name": "I",
+ "grounding": {
+ "identifiers": {
+ "ido": "0000511"
+ },
+ "modifiers": {}
+ },
+ "units": {
+ "expression": "person",
+ "expression_mathml": "person"
+ }
+ },
+ {
+ "id": "E",
+ "name": "E",
+ "grounding": {
+ "identifiers": {
+ "apollosv": "0000154"
+ },
+ "modifiers": {}
+ },
+ "units": {
+ "expression": "person",
+ "expression_mathml": "person"
+ }
+ },
+ {
+ "id": "R",
+ "name": "R",
+ "grounding": {
+ "identifiers": {
+ "ido": "0000592"
+ },
+ "modifiers": {}
+ },
+ "units": {
+ "expression": "person",
+ "expression_mathml": "person"
+ }
+ },
+ {
+ "id": "H",
+ "name": "H",
+ "grounding": {
+ "identifiers": {
+ "ido": "0000511"
+ },
+ "modifiers": {
+ "property": "ncit:C25179"
+ }
+ },
+ "units": {
+ "expression": "person",
+ "expression_mathml": "person"
+ }
+ },
+ {
+ "id": "D",
+ "name": "D",
+ "grounding": {
+ "identifiers": {
+ "ncit": "C28554"
+ },
+ "modifiers": {}
+ },
+ "units": {
+ "expression": "person",
+ "expression_mathml": "person"
+ }
+ }
+ ],
+ "transitions": [
+ {
+ "id": "t1",
+ "input": [
+ "I",
+ "S"
+ ],
+ "output": [
+ "I",
+ "E"
+ ],
+ "properties": {
+ "name": "t1"
+ }
+ },
+ {
+ "id": "t2",
+ "input": [
+ "E"
+ ],
+ "output": [
+ "I"
+ ],
+ "properties": {
+ "name": "t2"
+ }
+ },
+ {
+ "id": "t3",
+ "input": [
+ "I"
+ ],
+ "output": [
+ "R"
+ ],
+ "properties": {
+ "name": "t3"
+ }
+ },
+ {
+ "id": "t4",
+ "input": [
+ "I"
+ ],
+ "output": [
+ "H"
+ ],
+ "properties": {
+ "name": "t4"
+ }
+ },
+ {
+ "id": "t5",
+ "input": [
+ "H"
+ ],
+ "output": [
+ "R"
+ ],
+ "properties": {
+ "name": "t5"
+ }
+ },
+ {
+ "id": "t6",
+ "input": [
+ "H"
+ ],
+ "output": [
+ "D"
+ ],
+ "properties": {
+ "name": "t6"
+ }
+ }
+ ]
+ },
+ "semantics": {
+ "ode": {
+ "rates": [
+ {
+ "target": "t1",
+ "expression": "I*S*kappa*(beta_c + (-beta_c + beta_s)/(1 + exp(-k*(-t + t0))))/total_population",
+ "expression_mathml": "ISkappabeta_cbeta_cbeta_s1kt0ttotal_population"
+ },
+ {
+ "target": "t2",
+ "expression": "E*delta",
+ "expression_mathml": "Edelta"
+ },
+ {
+ "target": "t3",
+ "expression": "I*gamma*(1 - hosp)",
+ "expression_mathml": "Igamma1hosp"
+ },
+ {
+ "target": "t4",
+ "expression": "I*gamma*hosp",
+ "expression_mathml": "Igammahosp"
+ },
+ {
+ "target": "t5",
+ "expression": "H*(1 - death_hosp)/los",
+ "expression_mathml": "H1death_hosplos"
+ },
+ {
+ "target": "t6",
+ "expression": "H*death_hosp/los",
+ "expression_mathml": "Hdeath_hosplos"
+ }
+ ],
+ "initials": [
+ {
+ "target": "S",
+ "expression": "total_population - I0",
+ "expression_mathml": "total_populationI0"
+ },
+ {
+ "target": "I",
+ "expression": "I0",
+ "expression_mathml": "I0"
+ },
+ {
+ "target": "E",
+ "expression": "40.0000000000000",
+ "expression_mathml": "40.0"
+ },
+ {
+ "target": "R",
+ "expression": "0.0",
+ "expression_mathml": "0.0"
+ },
+ {
+ "target": "H",
+ "expression": "0.0",
+ "expression_mathml": "0.0"
+ },
+ {
+ "target": "D",
+ "expression": "0.0",
+ "expression_mathml": "0.0"
+ }
+ ],
+ "parameters": [
+ {
+ "id": "beta_c",
+ "value": 0.4,
+ "distribution": {
+ "type": "Uniform1",
+ "parameters": {
+ "minimum": 0.1,
+ "maximum": 0.8
+ }
+ },
+ "units": {
+ "expression": "1/day",
+ "expression_mathml": "day-1"
+ }
+ },
+ {
+ "id": "beta_s",
+ "value": 1.0,
+ "units": {
+ "expression": "1/day",
+ "expression_mathml": "day-1"
+ }
+ },
+ {
+ "id": "k",
+ "value": 5.0,
+ "units": {
+ "expression": "1",
+ "expression_mathml": "1"
+ }
+ },
+ {
+ "id": "kappa",
+ "value": 0.4,
+ "distribution": {
+ "type": "Uniform1",
+ "parameters": {
+ "minimum": 0.05,
+ "maximum": 0.8
+ }
+ },
+ "units": {
+ "expression": "1/day",
+ "expression_mathml": "day-1"
+ }
+ },
+ {
+ "id": "t0",
+ "value": 89.0,
+ "units": {
+ "expression": "day",
+ "expression_mathml": "day"
+ }
+ },
+ {
+ "id": "total_population",
+ "value": 19340000.0,
+ "units": {
+ "expression": "person",
+ "expression_mathml": "person"
+ }
+ },
+ {
+ "id": "I0",
+ "value": 10.0,
+ "distribution": {
+ "type": "Uniform1",
+ "parameters": {
+ "minimum": 1.0,
+ "maximum": 15.0
+ }
+ },
+ "units": {
+ "expression": "person",
+ "expression_mathml": "person"
+ }
+ },
+ {
+ "id": "delta",
+ "value": 0.25,
+ "units": {
+ "expression": "1/day",
+ "expression_mathml": "day-1"
+ }
+ },
+ {
+ "id": "gamma",
+ "value": 0.2,
+ "distribution": {
+ "type": "Uniform1",
+ "parameters": {
+ "minimum": 0.1,
+ "maximum": 0.5
+ }
+ },
+ "units": {
+ "expression": "1/day",
+ "expression_mathml": "day-1"
+ }
+ },
+ {
+ "id": "hosp",
+ "value": 0.1,
+ "distribution": {
+ "type": "Uniform1",
+ "parameters": {
+ "minimum": 0.005,
+ "maximum": 0.2
+ }
+ },
+ "units": {
+ "expression": "1/day",
+ "expression_mathml": "day-1"
+ }
+ },
+ {
+ "id": "death_hosp",
+ "value": 0.07,
+ "distribution": {
+ "type": "Uniform1",
+ "parameters": {
+ "minimum": 0.01,
+ "maximum": 0.1
+ }
+ },
+ "units": {
+ "expression": "1/day",
+ "expression_mathml": "day-1"
+ }
+ },
+ {
+ "id": "los",
+ "value": 5.0,
+ "units": {
+ "expression": "day",
+ "expression_mathml": "day"
+ }
+ }
+ ],
+ "observables": [
+ {
+ "id": "infected",
+ "name": "infected",
+ "expression": "I",
+ "expression_mathml": "I"
+ },
+ {
+ "id": "exposed",
+ "name": "exposed",
+ "expression": "E",
+ "expression_mathml": "E"
+ },
+ {
+ "id": "hospitalized",
+ "name": "hospitalized",
+ "expression": "H",
+ "expression_mathml": "H"
+ },
+ {
+ "id": "dead",
+ "name": "dead",
+ "expression": "D",
+ "expression_mathml": "D"
+ }
+ ],
+ "time": {
+ "id": "t",
+ "units": {
+ "expression": "day",
+ "expression_mathml": "day"
+ }
+ }
+ }
+ },
+ "metadata": {
+ "annotations": {
+ "license": null,
+ "authors": [],
+ "references": [],
+ "time_scale": null,
+ "time_start": null,
+ "time_end": null,
+ "locations": [],
+ "pathogens": [],
+ "diseases": [],
+ "hosts": [],
+ "model_types": []
+ }
+ }
+ },
+ "id": "right",
+ "name": "right",
+ "description": "no changes made",
+ "timestamp": "2023-08-03T18:00:40",
+ "model_id": "b4a38a62-16d9-4317-8cfc-510b37152776",
+ "amr_configuration": null,
+ "calibrated": false,
+ "calibration": null,
+ "calibration_score": null
+}
\ No newline at end of file
diff --git a/tests/integration/test_ensemble_calibrate.py b/tests/integration/test_ensemble_calibrate.py
new file mode 100644
index 0000000..ff835f1
--- /dev/null
+++ b/tests/integration/test_ensemble_calibrate.py
@@ -0,0 +1,73 @@
+import json
+
+import pytest
+
+from service.settings import settings
+
+TDS_URL = settings.TDS_URL
+
+
+@pytest.mark.example_dir("ensemble-calibrate")
+def test_ensemble_calibrate_example(
+ example_context, client, worker, file_storage, file_check, requests_mock
+):
+ job_id = "9036f8a8-7e55-4e77-aeec-e8d4ca120d67"
+
+ request = example_context["request"]
+ config_ids = [
+ config["id"] for config in example_context["request"]["model_configs"]
+ ]
+ for config_id in config_ids:
+ model = json.loads(example_context["fetch"](config_id + ".json"))
+ requests_mock.get(f"{TDS_URL}/model-configurations/{config_id}", json=model)
+
+ dataset_id = example_context["request"]["dataset"]["id"]
+ filename = example_context["request"]["dataset"]["filename"]
+ dataset = example_context["fetch"](filename, True)
+ dataset_loc = {"method": "GET", "url": dataset}
+ requests_mock.get(
+ f"{TDS_URL}/datasets/{dataset_id}/download-url?filename={filename}",
+ json=dataset_loc,
+ )
+ requests_mock.get("http://dataset", text=dataset)
+
+ requests_mock.post(f"{TDS_URL}/simulations", json={"id": str(job_id)})
+
+ response = client.post(
+ "/ensemble-calibrate",
+ json=request,
+ headers={"Content-Type": "application/json"},
+ )
+ simulation_id = response.json()["simulation_id"]
+ response = client.get(
+ f"/status/{simulation_id}",
+ )
+ status = response.json()["status"]
+ assert status == "queued"
+
+ tds_sim = example_context["tds_simulation"]
+ tds_sim["id"] = simulation_id
+
+ requests_mock.get(f"{TDS_URL}/simulations/{simulation_id}", json=tds_sim)
+ requests_mock.put(
+ f"{TDS_URL}/simulations/{simulation_id}", json={"status": "success"}
+ )
+
+ worker.work(burst=True)
+
+ response = client.get(
+ f"/status/{simulation_id}",
+ )
+ status = response.json()["status"]
+ result = file_storage("result.csv")
+ viz = file_storage("visualization.json")
+ # eval = file_storage("eval.csv") # NOTE: Do we want to check this
+
+ # Checks
+ assert status == "complete"
+
+ assert result is not None
+ assert file_check("csv", result)
+
+ assert viz is not None
+ assert file_check("json", viz)
diff --git a/tests/test_conversions.py b/tests/test_conversions.py
index eee3372..9a197cc 100644
--- a/tests/test_conversions.py
+++ b/tests/test_conversions.py
@@ -7,6 +7,7 @@
sample,
calibrate,
ensemble_sample,
+ ensemble_calibrate,
optimize,
) # noqa: F401
@@ -14,6 +15,7 @@
Simulate,
Calibrate,
EnsembleSimulate,
+ EnsembleCalibrate,
Optimize,
)
from service.settings import settings
@@ -98,6 +100,36 @@ def test_example_conversion(self, example_context, requests_mock):
is_satisfactory(kwargs, ensemble_sample)
+class TestEnsembleCalibrate:
+ @pytest.mark.example_dir("ensemble-calibrate")
+ def test_example_conversion(self, example_context, requests_mock):
+ job_id = example_context["tds_simulation"]["id"]
+
+ config_ids = [
+ config["id"] for config in example_context["request"]["model_configs"]
+ ]
+ for config_id in config_ids:
+ model = json.loads(example_context["fetch"](config_id + ".json"))
+ requests_mock.get(f"{TDS_URL}/model-configurations/{config_id}", json=model)
+
+ dataset_id = example_context["request"]["dataset"]["id"]
+ filename = example_context["request"]["dataset"]["filename"]
+ dataset = example_context["fetch"](filename, True)
+ dataset_loc = {"method": "GET", "url": dataset}
+ requests_mock.get(
+ f"{TDS_URL}/datasets/{dataset_id}/download-url?filename={filename}",
+ json=dataset_loc,
+ )
+ requests_mock.get("http://dataset", text=dataset)
+
+ ### Act and Assert
+
+ operation_request = EnsembleCalibrate(**example_context["request"])
+ kwargs = operation_request.gen_pyciemss_args(job_id)
+
+ is_satisfactory(kwargs, ensemble_calibrate)
+
+
class TestOptimize:
@pytest.mark.example_dir("optimize")
def test_example_conversion(self, example_context, requests_mock):
From 4fe4fa7d0241a142a4c001efefdb9bf6ca54cbc3 Mon Sep 17 00:00:00 2001
From: Five Grant <5@fivegrant.com>
Date: Tue, 19 Mar 2024 09:59:41 -0500
Subject: [PATCH 2/3] Swap out dataset for test
---
tests/examples/ensemble-calibrate/input/ensemble.csv | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/tests/examples/ensemble-calibrate/input/ensemble.csv b/tests/examples/ensemble-calibrate/input/ensemble.csv
index a57f834..80f9147 100644
--- a/tests/examples/ensemble-calibrate/input/ensemble.csv
+++ b/tests/examples/ensemble-calibrate/input/ensemble.csv
@@ -1,6 +1,4 @@
tstep,Infected,Hospitalizations
-0.0,1.0,1.9617580129804857e-12
-1.0,0.9480640292167664,0.01719011925160885
-2.0,1.066840648651123,0.031743988394737244
-3.0,1.310190200805664,0.04602520912885666
-4.0,1.66862154006958,0.06171417236328125
+1.1,15.0,0.1
+2.2,18.0,1.0
+3.3,20.0,2.2
From 628904f2d3c171cbbae305c69a8a2c80ecc4f9c7 Mon Sep 17 00:00:00 2001
From: Five Grant <5@fivegrant.com>
Date: Tue, 19 Mar 2024 11:10:56 -0500
Subject: [PATCH 3/3] Handle inferred parameters from ensemble-calibrate
---
service/models/operations/ensemble_simulate.py | 17 ++++++++++++++---
tests/integration/test_ensemble_calibrate.py | 11 +++++------
2 files changed, 19 insertions(+), 9 deletions(-)
diff --git a/service/models/operations/ensemble_simulate.py b/service/models/operations/ensemble_simulate.py
index bb68157..adc5fb6 100644
--- a/service/models/operations/ensemble_simulate.py
+++ b/service/models/operations/ensemble_simulate.py
@@ -1,19 +1,24 @@
from __future__ import annotations
-from typing import ClassVar, List
+from typing import ClassVar, List, Optional
from pydantic import BaseModel, Field, Extra
import torch # TODO: Do not use Torch in PyCIEMSS Library interface
from models.base import OperationRequest, Timespan, ModelConfig
from models.converters import convert_to_solution_mapping
-from utils.tds import fetch_model
+from utils.tds import fetch_model, fetch_inferred_parameters
class EnsembleSimulateExtra(BaseModel):
num_samples: int = Field(
100, description="number of samples for a CIEMSS simulation", example=100
)
+ inferred_parameters: Optional[str] = Field(
+ None,
+ description="id from a previous calibration",
+ example=None,
+ )
class EnsembleSimulate(OperationRequest):
@@ -38,6 +43,11 @@ def gen_pyciemss_args(self, job_id):
]
amr_paths = [fetch_model(config.id, job_id) for config in self.model_configs]
+ extra_options = self.extra.dict()
+ inferred_parameters = fetch_inferred_parameters(
+ extra_options.pop("inferred_parameters"), job_id
+ )
+
return {
"model_paths_or_jsons": amr_paths,
"solution_mappings": solution_mappings,
@@ -45,8 +55,9 @@ def gen_pyciemss_args(self, job_id):
"end_time": self.timespan.end,
"logging_step_size": self.step_size,
"dirichlet_alpha": weights,
+ "inferred_parameters": inferred_parameters,
# "visual_options": True,
- **self.extra.dict(),
+ **extra_options,
}
class Config:
diff --git a/tests/integration/test_ensemble_calibrate.py b/tests/integration/test_ensemble_calibrate.py
index ff835f1..26c3d13 100644
--- a/tests/integration/test_ensemble_calibrate.py
+++ b/tests/integration/test_ensemble_calibrate.py
@@ -59,15 +59,14 @@ def test_ensemble_calibrate_example(
f"/status/{simulation_id}",
)
status = response.json()["status"]
- result = file_storage("result.csv")
- viz = file_storage("visualization.json")
+ params = file_storage("parameters.pickle")
+ # viz = file_storage("visualization.json")
# eval = file_storage("eval.csv") # NOTE: Do we want to check this
# Checks
assert status == "complete"
- assert result is not None
- assert file_check("csv", result)
+ assert params is not None
- assert viz is not None
- assert file_check("json", viz)
+ # assert viz is not None
+ # assert file_check("json", viz)