From 4e7445d344685d2672ccc22a1857b1c7e0e6dfc8 Mon Sep 17 00:00:00 2001 From: Tom Szendrey Date: Wed, 22 May 2024 11:36:12 -0400 Subject: [PATCH 1/3] rename for clarity + add fixed interventions --- service/models/operations/optimize.py | 39 ++++++++++++++++----------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/service/models/operations/optimize.py b/service/models/operations/optimize.py index 485bfc6..6ba875f 100644 --- a/service/models/operations/optimize.py +++ b/service/models/operations/optimize.py @@ -6,14 +6,14 @@ import numpy as np import torch from pydantic import BaseModel, Field, Extra -from models.base import OperationRequest, Timespan +from models.base import OperationRequest, Timespan, InterventionObject from pyciemss.integration_utils.intervention_builder import ( param_value_objective, start_time_objective, ) from pyciemss.ouu.qoi import obs_nday_average_qoi, obs_max_qoi - +from models.converters import convert_to_static_interventions from utils.tds import fetch_model, fetch_inferred_parameters @@ -80,7 +80,7 @@ class Optimize(OperationRequest): pyciemss_lib_function: ClassVar[str] = "optimize" model_config_id: str = Field(..., example="ba8da8d4-047d-11ee-be56") timespan: Timespan = Timespan(start=0, end=90) - interventions: InterventionObjective + policy_interventions: InterventionObjective step_size: float = 1.0 qoi: QOI risk_bound: float @@ -90,29 +90,37 @@ class Optimize(OperationRequest): None, description="optional extra system specific arguments for advanced use cases", ) + fixed_static_parameter_interventions: List[InterventionObject] = Field( + default_factory=list, + description="The interventions provided via the model config which are not being optimized", + ) def gen_pyciemss_args(self, job_id): # Get model from TDS amr_path = fetch_model(self.model_config_id, job_id) - - intervention_type = self.interventions.selection + fixed_static_parameter_interventions = convert_to_static_interventions( + self.fixed_static_parameter_interventions + ) + intervention_type = self.policy_interventions.selection if intervention_type == "param_value": - assert self.interventions.start_time is not None - start_time = [torch.tensor(time) for time in self.interventions.start_time] - param_value = [None] * len(self.interventions.param_names) + assert self.policy_interventions.start_time is not None + start_time = [ + torch.tensor(time) for time in self.policy_interventions.start_time + ] + param_value = [None] * len(self.policy_interventions.param_names) - interventions = param_value_objective( + policy_interventions = param_value_objective( start_time=start_time, - param_name=self.interventions.param_names, + param_name=self.policy_interventions.param_names, param_value=param_value, ) else: - assert self.interventions.param_values is not None + assert self.policy_interventions.param_values is not None param_value = [ - torch.tensor(value) for value in self.interventions.param_values + torch.tensor(value) for value in self.policy_interventions.param_values ] - interventions = start_time_objective( - param_name=self.interventions.param_names, + policy_interventions = start_time_objective( + param_name=self.policy_interventions.param_names, param_value=param_value, ) @@ -133,7 +141,8 @@ def gen_pyciemss_args(self, job_id): "risk_bound": self.risk_bound, "initial_guess_interventions": self.initial_guess_interventions, "bounds_interventions": self.bounds_interventions, - "static_parameter_interventions": interventions, + "static_parameter_interventions": policy_interventions, + "fixed_static_parameter_interventions": fixed_static_parameter_interventions, "inferred_parameters": inferred_parameters, "n_samples_ouu": n_samples_ouu, **extra_options, From 10df752b33be7679f5bf17db7cbe2f9d474f1352 Mon Sep 17 00:00:00 2001 From: Tom Szendrey Date: Tue, 28 May 2024 11:01:51 -0400 Subject: [PATCH 2/3] updating test output --- tests/examples/optimize/output/tds_simulation.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/examples/optimize/output/tds_simulation.json b/tests/examples/optimize/output/tds_simulation.json index c81894d..1aa82b9 100644 --- a/tests/examples/optimize/output/tds_simulation.json +++ b/tests/examples/optimize/output/tds_simulation.json @@ -14,7 +14,7 @@ "start": 0, "end": 90 }, - "interventions": [], + "policy_interventions": [], "extra": { "num_samples": 100 } From b2943886c8d4b4815ffdd9c5f579b1d34ae20a25 Mon Sep 17 00:00:00 2001 From: Tom Szendrey Date: Tue, 28 May 2024 13:26:41 -0400 Subject: [PATCH 3/3] correcting tests this time i believe --- tests/examples/optimize/input/request.json | 2 +- tests/examples/optimize/output/tds_simulation.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/examples/optimize/input/request.json b/tests/examples/optimize/input/request.json index 7a996ca..e8b2a8b 100644 --- a/tests/examples/optimize/input/request.json +++ b/tests/examples/optimize/input/request.json @@ -2,7 +2,7 @@ "engine": "ciemss", "user_id": "not_provided", "model_config_id": "sidarthe", - "interventions": { + "policy_interventions": { "selection": "param_value", "start_time": [2], "param_names": ["beta"], diff --git a/tests/examples/optimize/output/tds_simulation.json b/tests/examples/optimize/output/tds_simulation.json index 1aa82b9..c81894d 100644 --- a/tests/examples/optimize/output/tds_simulation.json +++ b/tests/examples/optimize/output/tds_simulation.json @@ -14,7 +14,7 @@ "start": 0, "end": 90 }, - "policy_interventions": [], + "interventions": [], "extra": { "num_samples": 100 }