From 05f8e46ecaac4020e9dccb33a97bd03261a12d0a Mon Sep 17 00:00:00 2001 From: Tom Szendrey Date: Wed, 28 Aug 2024 16:27:48 -0400 Subject: [PATCH] Update interventions to have distinct param and state. (#115) * updating intervention. distinct param and state. * Please retest this Shawn * adding comment for David to be happy --- service/models/converters.py | 37 +++++++++++++++++--------- service/models/operations/calibrate.py | 22 +++++++++++---- service/models/operations/optimize.py | 11 +++++--- service/models/operations/simulate.py | 20 ++++++++------ 4 files changed, 61 insertions(+), 29 deletions(-) diff --git a/service/models/converters.py b/service/models/converters.py index eeacdac..583ab91 100644 --- a/service/models/converters.py +++ b/service/models/converters.py @@ -9,7 +9,7 @@ def fetch_and_convert_static_interventions(policy_intervention_id, job_id): if not (policy_intervention_id): - return defaultdict(dict) + return defaultdict(dict), defaultdict(dict) policy_intervention = fetch_interventions(policy_intervention_id, job_id) interventionList: list[HMIIntervention] = [] for inter in policy_intervention["interventions"]: @@ -26,7 +26,7 @@ def fetch_and_convert_static_interventions(policy_intervention_id, job_id): def fetch_and_convert_dynamic_interventions(policy_intervention_id, job_id): if not (policy_intervention_id): - return defaultdict(dict) + return defaultdict(dict), defaultdict(dict) policy_intervention = fetch_interventions(policy_intervention_id, job_id) interventionList: list[HMIIntervention] = [] for inter in policy_intervention["interventions"]: @@ -44,15 +44,19 @@ def fetch_and_convert_dynamic_interventions(policy_intervention_id, job_id): # Used to convert from HMI Intervention Policy -> pyciemss static interventions. def convert_static_interventions(interventions: list[HMIIntervention]): if not (interventions): - return defaultdict(dict) - static_interventions: Dict[torch.Tensor, Dict[str, any]] = defaultdict(dict) + return defaultdict(dict), defaultdict(dict) + static_param_interventions: Dict[torch.Tensor, Dict[str, any]] = defaultdict(dict) + static_state_interventions: Dict[torch.Tensor, Dict[str, any]] = defaultdict(dict) for inter in interventions: for static_inter in inter.static_interventions: time = torch.tensor(float(static_inter.timestep)) parameter_name = inter.applied_to value = torch.tensor(float(static_inter.value)) - static_interventions[time][parameter_name] = value - return static_interventions + if inter.type == "parameter": + static_param_interventions[time][parameter_name] = value + if inter.type == "state": + static_state_interventions[time][parameter_name] = value + return static_param_interventions, static_state_interventions # Define the threshold for when the intervention should be applied. @@ -68,11 +72,15 @@ def var_threshold(time, state): # Used to convert from HMI Intervention Policy -> pyciemss dynamic interventions. def convert_dynamic_interventions(interventions: list[HMIIntervention]): if not (interventions): - return defaultdict(dict) + return defaultdict(dict), defaultdict(dict) dynamic_parameter_interventions: Dict[ Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor], Dict[str, any], ] = defaultdict(dict) + dynamic_state_interventions: Dict[ + Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor], + Dict[str, any], + ] = defaultdict(dict) for inter in interventions: for dynamic_inter in inter.dynamic_interventions: parameter_name = inter.applied_to @@ -81,11 +89,16 @@ def convert_dynamic_interventions(interventions: list[HMIIntervention]): threshold_func = make_var_threshold( dynamic_inter.parameter, threshold_value ) - dynamic_parameter_interventions[threshold_func].update( - {parameter_name: to_value} - ) - - return dynamic_parameter_interventions + if inter.type == "parameter": + dynamic_parameter_interventions[threshold_func].update( + {parameter_name: to_value} + ) + if inter.type == "state": + dynamic_state_interventions[threshold_func].update( + {parameter_name: to_value} + ) + + return dynamic_parameter_interventions, dynamic_state_interventions def convert_to_solution_mapping(config): diff --git a/service/models/operations/calibrate.py b/service/models/operations/calibrate.py index c49fe90..c07f834 100644 --- a/service/models/operations/calibrate.py +++ b/service/models/operations/calibrate.py @@ -9,7 +9,10 @@ from models.base import Dataset, OperationRequest, Timespan -from models.converters import fetch_and_convert_static_interventions +from models.converters import ( + fetch_and_convert_static_interventions, + fetch_and_convert_dynamic_interventions, +) from utils.rabbitmq import gen_rabbitmq_hook from utils.tds import fetch_dataset, fetch_model @@ -62,9 +65,15 @@ def gen_pyciemss_args(self, job_id): dataset_path = fetch_dataset(self.dataset.dict(), job_id) - static_interventions = fetch_and_convert_static_interventions( - self.policy_intervention_id, job_id - ) + ( + static_param_interventions, + static_state_interventions, + ) = fetch_and_convert_static_interventions(self.policy_intervention_id, job_id) + + ( + dynamic_param_interventions, + dynamic_state_interventions, + ) = fetch_and_convert_dynamic_interventions(self.policy_intervention_id, job_id) # TODO: Test RabbitMQ try: @@ -95,7 +104,10 @@ def hook(progress, _loss): # TODO: Is this intentionally missing from `calibrate`? # "end_time": self.timespan.end, "data_path": dataset_path, - "static_parameter_interventions": static_interventions, + "static_parameter_interventions": static_param_interventions, + "static_state_interventions": static_state_interventions, + "dynamic_parameter_interventions": dynamic_param_interventions, + "dynamic_state_interventions": dynamic_state_interventions, "progress_hook": hook, "solver_method": solver_method, "solver_options": solver_options, diff --git a/service/models/operations/optimize.py b/service/models/operations/optimize.py index faf7bce..9ce386a 100644 --- a/service/models/operations/optimize.py +++ b/service/models/operations/optimize.py @@ -90,7 +90,7 @@ class Optimize(OperationRequest): model_config_id: str = Field(..., example="ba8da8d4-047d-11ee-be56") timespan: Timespan = Timespan(start=0, end=90) optimize_interventions: InterventionObjective # These are the interventions to be optimized. - fixed_static_parameter_interventions: list[HMIIntervention] = Field( + fixed_interventions: list[HMIIntervention] = Field( None ) # Theses are interventions provided that will not be optimized logging_step_size: float = 1.0 @@ -105,9 +105,10 @@ class Optimize(OperationRequest): def gen_pyciemss_args(self, job_id): # Get model from TDS amr_path = fetch_model(self.model_config_id, job_id) - fixed_static_parameter_interventions = convert_static_interventions( - self.fixed_static_parameter_interventions - ) + ( + fixed_static_parameter_interventions, + fixed_static_state_interventions, + ) = convert_static_interventions(self.fixed_interventions) intervention_type = self.optimize_interventions.intervention_type if intervention_type == "param_value": @@ -166,6 +167,8 @@ def gen_pyciemss_args(self, job_id): "bounds_interventions": self.bounds_interventions, "static_parameter_interventions": optimize_interventions, "fixed_static_parameter_interventions": fixed_static_parameter_interventions, + # https://github.com/DARPA-ASKEM/terarium/issues/4612 + # "fixed_static_state_interventions": fixed_static_state_interventions, "inferred_parameters": inferred_parameters, "n_samples_ouu": n_samples_ouu, "solver_method": solver_method, diff --git a/service/models/operations/simulate.py b/service/models/operations/simulate.py index 5e3ca60..6698a1d 100644 --- a/service/models/operations/simulate.py +++ b/service/models/operations/simulate.py @@ -49,13 +49,15 @@ def gen_pyciemss_args(self, job_id): # Get model from TDS amr_path = fetch_model(self.model_config_id, job_id) - static_interventions = fetch_and_convert_static_interventions( - self.policy_intervention_id, job_id - ) + ( + static_param_interventions, + static_state_interventions, + ) = fetch_and_convert_static_interventions(self.policy_intervention_id, job_id) - dynamic_interventions = fetch_and_convert_dynamic_interventions( - self.policy_intervention_id, job_id - ) + ( + dynamic_param_interventions, + dynamic_state_interventions, + ) = fetch_and_convert_dynamic_interventions(self.policy_intervention_id, job_id) extra_options = self.extra.dict() inferred_parameters = fetch_inferred_parameters( @@ -75,8 +77,10 @@ def gen_pyciemss_args(self, job_id): "logging_step_size": self.logging_step_size, "start_time": self.timespan.start, "end_time": self.timespan.end, - "static_parameter_interventions": static_interventions, - "dynamic_parameter_interventions": dynamic_interventions, + "static_parameter_interventions": static_param_interventions, + "static_state_interventions": static_state_interventions, + "dynamic_parameter_interventions": dynamic_param_interventions, + "dynamic_state_interventions": dynamic_state_interventions, "inferred_parameters": inferred_parameters, "solver_method": solver_method, "solver_options": solver_options,