From 47072d4c2544f469167ccad2abd6c8b629e7dd8f Mon Sep 17 00:00:00 2001 From: Tom Szendrey Date: Mon, 22 Jul 2024 16:48:53 -0400 Subject: [PATCH 1/2] lots of testing todo --- service/models/converters.py | 48 ++++++++++++++++++++++++++- service/models/operations/simulate.py | 10 +++++- 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/service/models/converters.py b/service/models/converters.py index 0616494..580f960 100644 --- a/service/models/converters.py +++ b/service/models/converters.py @@ -3,7 +3,7 @@ # TODO: Do not use Torch in PyCIEMSS Library interface import torch from utils.tds import fetch_interventions -from typing import Dict +from typing import Dict, Callable from models.base import HMIIntervention @@ -24,6 +24,23 @@ def fetch_and_convert_static_interventions(policy_intervention_id, job_id): return convert_static_interventions(interventionList) +def fetch_and_convert_dynamic_interventions(policy_intervention_id, job_id): + if not (policy_intervention_id): + return defaultdict(dict) + policy_intervention = fetch_interventions(policy_intervention_id, job_id) + interventionList: list[HMIIntervention] = [] + for inter in policy_intervention["interventions"]: + intervention = HMIIntervention( + name=inter["name"], + applied_to=inter["applied_to"], + type=inter["type"], + static_interventions=inter["static_interventions"], + dynamic_interventions=inter["dynamic_interventions"], + ) + interventionList.append(intervention) + return convert_dynamic_interventions(interventionList) + + # Used to convert from HMI Intervention Policy -> pyciemss static interventions. def convert_static_interventions(interventions: list[HMIIntervention]): if not (interventions): @@ -38,6 +55,35 @@ def convert_static_interventions(interventions: list[HMIIntervention]): return static_interventions +def make_var_threshold(var: str, threshold: torch.Tensor): + def var_threshold(time, state): + return state[var] - threshold + + return var_threshold + + +def convert_dynamic_interventions(interventions: list[HMIIntervention]): + if not (interventions): + return defaultdict(dict) + dynamic_parameter_interventions: Dict[ + Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor], + Dict[str, any], + ] = defaultdict(dict) + for inter in interventions: + for dynamic_inter in inter.dynamic_interventions: + parameter_name = inter.applied_to + threshold_value = torch.tensor(float(dynamic_inter.threshold)) + to_value = torch.tensor(float(dynamic_inter.value)) + threshold_func = make_var_threshold( + dynamic_inter.parameter, threshold_value + ) + dynamic_parameter_interventions[threshold_func].update( + {parameter_name: to_value} + ) + + return dynamic_parameter_interventions + + def convert_to_solution_mapping(config): individual_to_ensemble = { individual_state: ensemble_state diff --git a/service/models/operations/simulate.py b/service/models/operations/simulate.py index a162e9d..2aeeebb 100644 --- a/service/models/operations/simulate.py +++ b/service/models/operations/simulate.py @@ -5,7 +5,10 @@ from models.base import OperationRequest, Timespan -from models.converters import fetch_and_convert_static_interventions +from models.converters import ( + fetch_and_convert_static_interventions, + fetch_and_convert_dynamic_interventions, +) from utils.tds import fetch_model, fetch_inferred_parameters @@ -39,6 +42,10 @@ def gen_pyciemss_args(self, job_id): self.policy_intervention_id, job_id ) + dynamic_interventions = fetch_and_convert_dynamic_interventions( + self.policy_intervention_id, job_id + ) + extra_options = self.extra.dict() inferred_parameters = fetch_inferred_parameters( extra_options.pop("inferred_parameters"), job_id @@ -50,6 +57,7 @@ def gen_pyciemss_args(self, job_id): "start_time": self.timespan.start, "end_time": self.timespan.end, "static_parameter_interventions": static_interventions, + "dynamic_parameter_interventions": dynamic_interventions, "inferred_parameters": inferred_parameters, **extra_options, } From 586ad9f9c9ec34e6007b68634650437e35b7a867 Mon Sep 17 00:00:00 2001 From: Tom Szendrey Date: Tue, 23 Jul 2024 11:42:46 -0400 Subject: [PATCH 2/2] adding minor comments --- service/models/converters.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/service/models/converters.py b/service/models/converters.py index 580f960..eeacdac 100644 --- a/service/models/converters.py +++ b/service/models/converters.py @@ -55,6 +55,9 @@ def convert_static_interventions(interventions: list[HMIIntervention]): return static_interventions +# Define the threshold for when the intervention should be applied. +# Can support further functions options in the future +# https://github.com/ciemss/pyciemss/blob/main/docs/source/interfaces.ipynb def make_var_threshold(var: str, threshold: torch.Tensor): def var_threshold(time, state): return state[var] - threshold @@ -62,6 +65,7 @@ def var_threshold(time, state): return var_threshold +# Used to convert from HMI Intervention Policy -> pyciemss dynamic interventions. def convert_dynamic_interventions(interventions: list[HMIIntervention]): if not (interventions): return defaultdict(dict)