Skip to content

Commit

Permalink
adding OptimizeIntervention as a type
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom-Szendrey committed Feb 27, 2024
1 parent 1781e04 commit 5016e23
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 7 deletions.
5 changes: 5 additions & 0 deletions service/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ class InterventionObject(BaseModel):
value: float


class OptimizeInterventionObject(BaseModel):
timestep: float
name: str


class InterventionSelection(BaseModel):
timestep: float
name: str
Expand Down
7 changes: 7 additions & 0 deletions service/models/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ def convert_to_static_interventions(interventions):
return static_interventions


def convert_optimize_to_static_interventions(interventions):
static_interventions = defaultdict(dict)
for i in interventions:
static_interventions[i.timestep] = torch.tensor([i.name])
return static_interventions


def convert_to_solution_mapping(config):
individual_to_ensemble = {
individual_state: ensemble_state
Expand Down
10 changes: 5 additions & 5 deletions service/models/operations/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import numpy as np
import torch
from pydantic import BaseModel, Field, Extra
from models.base import OperationRequest, Timespan, InterventionObject
from models.converters import convert_to_static_interventions
from models.base import OperationRequest, Timespan, OptimizeInterventionObject
from models.converters import convert_optimize_to_static_interventions
from utils.tds import fetch_model, fetch_inferred_parameters


Expand Down Expand Up @@ -57,8 +57,8 @@ class Optimize(OperationRequest):
pyciemss_lib_function: ClassVar[str] = "optimize"
model_config_id: str = Field(..., example="ba8da8d4-047d-11ee-be56")
timespan: Timespan = Timespan(start=0, end=90)
interventions: List[InterventionObject] = Field(
default_factory=list, example=[{"timestep": 1, "name": "beta", "value": 0.4}]
interventions: List[OptimizeInterventionObject] = Field(
default_factory=list, example=[{"timestep": 1, "name": "beta"}]
)
step_size: float = 1.0
qoi: List[str] # QOIMethod
Expand All @@ -74,7 +74,7 @@ def gen_pyciemss_args(self, job_id):
# Get model from TDS
amr_path = fetch_model(self.model_config_id, job_id)

interventions = convert_to_static_interventions(self.interventions)
interventions = convert_optimize_to_static_interventions(self.interventions)

extra_options = self.extra.dict()
inferred_parameters = fetch_inferred_parameters(
Expand Down
3 changes: 1 addition & 2 deletions tests/examples/optimize/input/request.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
"interventions": [
{
"timestep": 1.0,
"name": "beta",
"value": 0.4
"name": "beta"
}
],
"timespan": {
Expand Down

0 comments on commit 5016e23

Please sign in to comment.