Skip to content

Commit

Permalink
lots of testing todo
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom-Szendrey committed Jul 22, 2024
1 parent 374050c commit 47072d4
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 2 deletions.
48 changes: 47 additions & 1 deletion service/models/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# TODO: Do not use Torch in PyCIEMSS Library interface
import torch
from utils.tds import fetch_interventions
from typing import Dict
from typing import Dict, Callable
from models.base import HMIIntervention


Expand All @@ -24,6 +24,23 @@ def fetch_and_convert_static_interventions(policy_intervention_id, job_id):
return convert_static_interventions(interventionList)


def fetch_and_convert_dynamic_interventions(policy_intervention_id, job_id):
if not (policy_intervention_id):
return defaultdict(dict)
policy_intervention = fetch_interventions(policy_intervention_id, job_id)
interventionList: list[HMIIntervention] = []
for inter in policy_intervention["interventions"]:
intervention = HMIIntervention(
name=inter["name"],
applied_to=inter["applied_to"],
type=inter["type"],
static_interventions=inter["static_interventions"],
dynamic_interventions=inter["dynamic_interventions"],
)
interventionList.append(intervention)
return convert_dynamic_interventions(interventionList)


# Used to convert from HMI Intervention Policy -> pyciemss static interventions.
def convert_static_interventions(interventions: list[HMIIntervention]):
if not (interventions):
Expand All @@ -38,6 +55,35 @@ def convert_static_interventions(interventions: list[HMIIntervention]):
return static_interventions


def make_var_threshold(var: str, threshold: torch.Tensor):
def var_threshold(time, state):
return state[var] - threshold

return var_threshold


def convert_dynamic_interventions(interventions: list[HMIIntervention]):
if not (interventions):
return defaultdict(dict)
dynamic_parameter_interventions: Dict[
Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
Dict[str, any],
] = defaultdict(dict)
for inter in interventions:
for dynamic_inter in inter.dynamic_interventions:
parameter_name = inter.applied_to
threshold_value = torch.tensor(float(dynamic_inter.threshold))
to_value = torch.tensor(float(dynamic_inter.value))
threshold_func = make_var_threshold(
dynamic_inter.parameter, threshold_value
)
dynamic_parameter_interventions[threshold_func].update(
{parameter_name: to_value}
)

return dynamic_parameter_interventions


def convert_to_solution_mapping(config):
individual_to_ensemble = {
individual_state: ensemble_state
Expand Down
10 changes: 9 additions & 1 deletion service/models/operations/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@


from models.base import OperationRequest, Timespan
from models.converters import fetch_and_convert_static_interventions
from models.converters import (
fetch_and_convert_static_interventions,
fetch_and_convert_dynamic_interventions,
)
from utils.tds import fetch_model, fetch_inferred_parameters


Expand Down Expand Up @@ -39,6 +42,10 @@ def gen_pyciemss_args(self, job_id):
self.policy_intervention_id, job_id
)

dynamic_interventions = fetch_and_convert_dynamic_interventions(
self.policy_intervention_id, job_id
)

extra_options = self.extra.dict()
inferred_parameters = fetch_inferred_parameters(
extra_options.pop("inferred_parameters"), job_id
Expand All @@ -50,6 +57,7 @@ def gen_pyciemss_args(self, job_id):
"start_time": self.timespan.start,
"end_time": self.timespan.end,
"static_parameter_interventions": static_interventions,
"dynamic_parameter_interventions": dynamic_interventions,
"inferred_parameters": inferred_parameters,
**extra_options,
}
Expand Down

0 comments on commit 47072d4

Please sign in to comment.