diff --git a/service/models/operations/optimize.py b/service/models/operations/optimize.py index 485bfc6..6ba875f 100644 --- a/service/models/operations/optimize.py +++ b/service/models/operations/optimize.py @@ -6,14 +6,14 @@ import numpy as np import torch from pydantic import BaseModel, Field, Extra -from models.base import OperationRequest, Timespan +from models.base import OperationRequest, Timespan, InterventionObject from pyciemss.integration_utils.intervention_builder import ( param_value_objective, start_time_objective, ) from pyciemss.ouu.qoi import obs_nday_average_qoi, obs_max_qoi - +from models.converters import convert_to_static_interventions from utils.tds import fetch_model, fetch_inferred_parameters @@ -80,7 +80,7 @@ class Optimize(OperationRequest): pyciemss_lib_function: ClassVar[str] = "optimize" model_config_id: str = Field(..., example="ba8da8d4-047d-11ee-be56") timespan: Timespan = Timespan(start=0, end=90) - interventions: InterventionObjective + policy_interventions: InterventionObjective step_size: float = 1.0 qoi: QOI risk_bound: float @@ -90,29 +90,37 @@ class Optimize(OperationRequest): None, description="optional extra system specific arguments for advanced use cases", ) + fixed_static_parameter_interventions: List[InterventionObject] = Field( + default_factory=list, + description="The interventions provided via the model config which are not being optimized", + ) def gen_pyciemss_args(self, job_id): # Get model from TDS amr_path = fetch_model(self.model_config_id, job_id) - - intervention_type = self.interventions.selection + fixed_static_parameter_interventions = convert_to_static_interventions( + self.fixed_static_parameter_interventions + ) + intervention_type = self.policy_interventions.selection if intervention_type == "param_value": - assert self.interventions.start_time is not None - start_time = [torch.tensor(time) for time in self.interventions.start_time] - param_value = [None] * len(self.interventions.param_names) + assert self.policy_interventions.start_time is not None + start_time = [ + torch.tensor(time) for time in self.policy_interventions.start_time + ] + param_value = [None] * len(self.policy_interventions.param_names) - interventions = param_value_objective( + policy_interventions = param_value_objective( start_time=start_time, - param_name=self.interventions.param_names, + param_name=self.policy_interventions.param_names, param_value=param_value, ) else: - assert self.interventions.param_values is not None + assert self.policy_interventions.param_values is not None param_value = [ - torch.tensor(value) for value in self.interventions.param_values + torch.tensor(value) for value in self.policy_interventions.param_values ] - interventions = start_time_objective( - param_name=self.interventions.param_names, + policy_interventions = start_time_objective( + param_name=self.policy_interventions.param_names, param_value=param_value, ) @@ -133,7 +141,8 @@ def gen_pyciemss_args(self, job_id): "risk_bound": self.risk_bound, "initial_guess_interventions": self.initial_guess_interventions, "bounds_interventions": self.bounds_interventions, - "static_parameter_interventions": interventions, + "static_parameter_interventions": policy_interventions, + "fixed_static_parameter_interventions": fixed_static_parameter_interventions, "inferred_parameters": inferred_parameters, "n_samples_ouu": n_samples_ouu, **extra_options, diff --git a/tests/examples/optimize/input/request.json b/tests/examples/optimize/input/request.json index 7a996ca..e8b2a8b 100644 --- a/tests/examples/optimize/input/request.json +++ b/tests/examples/optimize/input/request.json @@ -2,7 +2,7 @@ "engine": "ciemss", "user_id": "not_provided", "model_config_id": "sidarthe", - "interventions": { + "policy_interventions": { "selection": "param_value", "start_time": [2], "param_names": ["beta"],