Skip to content

Commit

Permalink
Update interventions to have distinct param and state. (#115)
Browse files Browse the repository at this point in the history
* updating intervention. distinct param and state.

* Please retest this Shawn

* adding comment for David to be happy
  • Loading branch information
Tom-Szendrey authored Aug 28, 2024
1 parent 1839dc9 commit 05f8e46
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 29 deletions.
37 changes: 25 additions & 12 deletions service/models/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def fetch_and_convert_static_interventions(policy_intervention_id, job_id):
if not (policy_intervention_id):
return defaultdict(dict)
return defaultdict(dict), defaultdict(dict)
policy_intervention = fetch_interventions(policy_intervention_id, job_id)
interventionList: list[HMIIntervention] = []
for inter in policy_intervention["interventions"]:
Expand All @@ -26,7 +26,7 @@ def fetch_and_convert_static_interventions(policy_intervention_id, job_id):

def fetch_and_convert_dynamic_interventions(policy_intervention_id, job_id):
if not (policy_intervention_id):
return defaultdict(dict)
return defaultdict(dict), defaultdict(dict)
policy_intervention = fetch_interventions(policy_intervention_id, job_id)
interventionList: list[HMIIntervention] = []
for inter in policy_intervention["interventions"]:
Expand All @@ -44,15 +44,19 @@ def fetch_and_convert_dynamic_interventions(policy_intervention_id, job_id):
# Used to convert from HMI Intervention Policy -> pyciemss static interventions.
def convert_static_interventions(interventions: list[HMIIntervention]):
if not (interventions):
return defaultdict(dict)
static_interventions: Dict[torch.Tensor, Dict[str, any]] = defaultdict(dict)
return defaultdict(dict), defaultdict(dict)
static_param_interventions: Dict[torch.Tensor, Dict[str, any]] = defaultdict(dict)
static_state_interventions: Dict[torch.Tensor, Dict[str, any]] = defaultdict(dict)
for inter in interventions:
for static_inter in inter.static_interventions:
time = torch.tensor(float(static_inter.timestep))
parameter_name = inter.applied_to
value = torch.tensor(float(static_inter.value))
static_interventions[time][parameter_name] = value
return static_interventions
if inter.type == "parameter":
static_param_interventions[time][parameter_name] = value
if inter.type == "state":
static_state_interventions[time][parameter_name] = value
return static_param_interventions, static_state_interventions


# Define the threshold for when the intervention should be applied.
Expand All @@ -68,11 +72,15 @@ def var_threshold(time, state):
# Used to convert from HMI Intervention Policy -> pyciemss dynamic interventions.
def convert_dynamic_interventions(interventions: list[HMIIntervention]):
if not (interventions):
return defaultdict(dict)
return defaultdict(dict), defaultdict(dict)
dynamic_parameter_interventions: Dict[
Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
Dict[str, any],
] = defaultdict(dict)
dynamic_state_interventions: Dict[
Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
Dict[str, any],
] = defaultdict(dict)
for inter in interventions:
for dynamic_inter in inter.dynamic_interventions:
parameter_name = inter.applied_to
Expand All @@ -81,11 +89,16 @@ def convert_dynamic_interventions(interventions: list[HMIIntervention]):
threshold_func = make_var_threshold(
dynamic_inter.parameter, threshold_value
)
dynamic_parameter_interventions[threshold_func].update(
{parameter_name: to_value}
)

return dynamic_parameter_interventions
if inter.type == "parameter":
dynamic_parameter_interventions[threshold_func].update(
{parameter_name: to_value}
)
if inter.type == "state":
dynamic_state_interventions[threshold_func].update(
{parameter_name: to_value}
)

return dynamic_parameter_interventions, dynamic_state_interventions


def convert_to_solution_mapping(config):
Expand Down
22 changes: 17 additions & 5 deletions service/models/operations/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@


from models.base import Dataset, OperationRequest, Timespan
from models.converters import fetch_and_convert_static_interventions
from models.converters import (
fetch_and_convert_static_interventions,
fetch_and_convert_dynamic_interventions,
)
from utils.rabbitmq import gen_rabbitmq_hook
from utils.tds import fetch_dataset, fetch_model

Expand Down Expand Up @@ -62,9 +65,15 @@ def gen_pyciemss_args(self, job_id):

dataset_path = fetch_dataset(self.dataset.dict(), job_id)

static_interventions = fetch_and_convert_static_interventions(
self.policy_intervention_id, job_id
)
(
static_param_interventions,
static_state_interventions,
) = fetch_and_convert_static_interventions(self.policy_intervention_id, job_id)

(
dynamic_param_interventions,
dynamic_state_interventions,
) = fetch_and_convert_dynamic_interventions(self.policy_intervention_id, job_id)

# TODO: Test RabbitMQ
try:
Expand Down Expand Up @@ -95,7 +104,10 @@ def hook(progress, _loss):
# TODO: Is this intentionally missing from `calibrate`?
# "end_time": self.timespan.end,
"data_path": dataset_path,
"static_parameter_interventions": static_interventions,
"static_parameter_interventions": static_param_interventions,
"static_state_interventions": static_state_interventions,
"dynamic_parameter_interventions": dynamic_param_interventions,
"dynamic_state_interventions": dynamic_state_interventions,
"progress_hook": hook,
"solver_method": solver_method,
"solver_options": solver_options,
Expand Down
11 changes: 7 additions & 4 deletions service/models/operations/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class Optimize(OperationRequest):
model_config_id: str = Field(..., example="ba8da8d4-047d-11ee-be56")
timespan: Timespan = Timespan(start=0, end=90)
optimize_interventions: InterventionObjective # These are the interventions to be optimized.
fixed_static_parameter_interventions: list[HMIIntervention] = Field(
fixed_interventions: list[HMIIntervention] = Field(
None
) # Theses are interventions provided that will not be optimized
logging_step_size: float = 1.0
Expand All @@ -105,9 +105,10 @@ class Optimize(OperationRequest):
def gen_pyciemss_args(self, job_id):
# Get model from TDS
amr_path = fetch_model(self.model_config_id, job_id)
fixed_static_parameter_interventions = convert_static_interventions(
self.fixed_static_parameter_interventions
)
(
fixed_static_parameter_interventions,
fixed_static_state_interventions,
) = convert_static_interventions(self.fixed_interventions)

intervention_type = self.optimize_interventions.intervention_type
if intervention_type == "param_value":
Expand Down Expand Up @@ -166,6 +167,8 @@ def gen_pyciemss_args(self, job_id):
"bounds_interventions": self.bounds_interventions,
"static_parameter_interventions": optimize_interventions,
"fixed_static_parameter_interventions": fixed_static_parameter_interventions,
# https://github.com/DARPA-ASKEM/terarium/issues/4612
# "fixed_static_state_interventions": fixed_static_state_interventions,
"inferred_parameters": inferred_parameters,
"n_samples_ouu": n_samples_ouu,
"solver_method": solver_method,
Expand Down
20 changes: 12 additions & 8 deletions service/models/operations/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,15 @@ def gen_pyciemss_args(self, job_id):
# Get model from TDS
amr_path = fetch_model(self.model_config_id, job_id)

static_interventions = fetch_and_convert_static_interventions(
self.policy_intervention_id, job_id
)
(
static_param_interventions,
static_state_interventions,
) = fetch_and_convert_static_interventions(self.policy_intervention_id, job_id)

dynamic_interventions = fetch_and_convert_dynamic_interventions(
self.policy_intervention_id, job_id
)
(
dynamic_param_interventions,
dynamic_state_interventions,
) = fetch_and_convert_dynamic_interventions(self.policy_intervention_id, job_id)

extra_options = self.extra.dict()
inferred_parameters = fetch_inferred_parameters(
Expand All @@ -75,8 +77,10 @@ def gen_pyciemss_args(self, job_id):
"logging_step_size": self.logging_step_size,
"start_time": self.timespan.start,
"end_time": self.timespan.end,
"static_parameter_interventions": static_interventions,
"dynamic_parameter_interventions": dynamic_interventions,
"static_parameter_interventions": static_param_interventions,
"static_state_interventions": static_state_interventions,
"dynamic_parameter_interventions": dynamic_param_interventions,
"dynamic_state_interventions": dynamic_state_interventions,
"inferred_parameters": inferred_parameters,
"solver_method": solver_method,
"solver_options": solver_options,
Expand Down

0 comments on commit 05f8e46

Please sign in to comment.