Skip to content

Commit

Permalink
Allow Interventions (#16)
Browse files Browse the repository at this point in the history
* interventions

* Create empty interventions by default

* Add logging of interventions

* Strip interventions from calibrate

* Fix typo

* Fix logging message

* Reinclude docker network

---------

Co-authored-by: Five Grant <5@fivegrant.com>
  • Loading branch information
marshHawk4 and fivegrant authored Jul 17, 2023
1 parent a71f9b8 commit c9fadcc
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 8 deletions.
5 changes: 5 additions & 0 deletions api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,17 @@ class Dataset(BaseModel):
example={'postive_tests': 'infected'},
)

class InterventionObject(BaseModel):
timestep: float
name: str
value: float

class SimulatePostRequest(BaseModel):
engine: Engine = Field(..., example="ciemss")
username: str = Field("not_provided", example="not_provided")
model_config_id: str = Field(..., example="ba8da8d4-047d-11ee-be56")
timespan: Timespan
interventions: List[InterventionObject] = Field(default_factory=list, example=[{"timestep":1,"name":"beta","value":.4}])
extra: SimulateExtra = Field(
None,
description="optional extra system specific arguments for advanced use cases",
Expand Down
21 changes: 16 additions & 5 deletions api/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
from fastapi import FastAPI, Response, status
from fastapi.middleware.cors import CORSMiddleware

Expand All @@ -13,6 +14,9 @@
)


logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)

def build_api(*args) -> FastAPI:

api = FastAPI(
Expand Down Expand Up @@ -53,7 +57,7 @@ def get_status(simulation_id: str) -> StatusSimulationIdGetResponse:
from utils import fetch_job_status

status = fetch_job_status(simulation_id)
print(status)
logging.info(status)
if not isinstance(status, str):
return status

Expand All @@ -72,6 +76,10 @@ def simulate_model(body: SimulatePostRequest) -> JobResponse:
username = body.username
start = body.timespan.start
end = body.timespan.end
interventions = [
(intervention.timestep, intervention.name, intervention.value) for intervention in body.interventions
]


operation_name = "operations.simulate"
options = {
Expand All @@ -81,11 +89,14 @@ def simulate_model(body: SimulatePostRequest) -> JobResponse:
"start": start,
"end": end,
"extra": body.extra.dict(),
"visual_options": True
"visual_options": True,
"interventions": interventions
}

resp = create_job(operation_name=operation_name, options=options)

if len(interventions) > 0:
logging.info(f"{resp['id']} used interventions: {interventions}")
response = {"simulation_id": resp["id"]}

return response
Expand All @@ -99,7 +110,7 @@ def calibrate_model(body: CalibratePostRequest) -> JobResponse:
from utils import create_job

# Parse request body
print(body)
logging.info(body)
engine = str(body.engine).lower()
username = body.username
model_config_id = body.model_config_id
Expand All @@ -108,7 +119,6 @@ def calibrate_model(body: CalibratePostRequest) -> JobResponse:
end = body.timespan.end
extra = body.extra.dict()


operation_name = "operations.calibrate_then_simulate"
options = {
"engine": engine,
Expand All @@ -118,7 +128,7 @@ def calibrate_model(body: CalibratePostRequest) -> JobResponse:
"end": end,
"dataset": dataset.dict(),
"extra": extra,
"visual_options": True
"visual_options": True,
}

resp = create_job(operation_name=operation_name, options=options)
Expand All @@ -139,3 +149,4 @@ def create_ensemble(body: EnsemblePostRequest) -> JobResponse:
)



6 changes: 3 additions & 3 deletions api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def create_job(operation_name: str, options: Optional[Dict[Any, Any]] = None):
job = q.fetch_job(job_id)

if STANDALONE:
print(f"OPTIONS: {options}")
logging.info(f"OPTIONS: {options}")
ex_payload = {
"engine": "ciemss",
"model_config_id": options.get("model_config_id"),
Expand All @@ -73,9 +73,9 @@ def create_job(operation_name: str, options: Optional[Dict[Any, Any]] = None):
"engine": "ciemss",
"workflow_id": job_id,
}
print(payload)
logging.info(payload)
sys.stdout.flush()
print(requests.put(post_url, json=json.loads(json.dumps(payload))).content)
logging.info(requests.put(post_url, json=json.loads(json.dumps(payload))).content)

if job and force_restart:
job.cleanup(ttl=0) # Cleanup/remove data immediately
Expand Down

0 comments on commit c9fadcc

Please sign in to comment.