Optimize new plumbing (#101)
* updating to new interventions and allowing fixed

* minor update to test case

* minor update to testcase

* adding initial_guess to test

* adding missing objective_function_option

* correcting converting indexing
Tom-Szendrey authored Jul 15, 2024
1 parent 54bdd88 commit ade0def
Showing 4 changed files with 77 additions and 32 deletions.
27 changes: 26 additions & 1 deletion service/models/base.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import ClassVar, Dict
from typing import ClassVar, Dict, Optional
from pydantic import BaseModel, Field


@@ -42,6 +42,31 @@ class InterventionSelection(BaseModel):
name: str


class HMIStaticIntervention(BaseModel):
timestep: float
value: float


class HMIDynamicIntervention(BaseModel):
parameter: str
threshold: float
value: float
is_greater_than: bool


class HMIInterventionPolicy(BaseModel):
model_id: Optional[str] = Field(default="")
Interventions: list[HMIIntervention]


class HMIIntervention(BaseModel):
name: str
applied_to: str
type: str
static_interventions: Optional[list[HMIStaticIntervention]] = Field(default=None)
dynamic_interventions: Optional[list[HMIDynamicIntervention]] = Field(default=None)


class OperationRequest(BaseModel):
pyciemss_lib_function: ClassVar[str] = ""
engine: str = Field("ciemss", example="ciemss")
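For orientation, the snippet below is a minimal sketch of how the new HMI intervention models compose; it is not part of the commit. The import path and field names mirror service/models/base.py above, while the concrete values and the "static" type string are assumptions. An HMIInterventionPolicy simply wraps a list of these interventions under its Interventions field.

# Illustrative only -- not part of this commit. Assumes the service package is
# importable; the values and the "static" type string are invented.
from models.base import HMIIntervention, HMIStaticIntervention

fixed = HMIIntervention(
    name="Cap beta at day 2",
    applied_to="beta",
    type="static",
    static_interventions=[HMIStaticIntervention(timestep=2.0, value=0.02)],
)
print(fixed)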
15 changes: 12 additions & 3 deletions service/models/converters.py
@@ -4,14 +4,23 @@
import torch
from utils.tds import fetch_interventions
from typing import Dict
from models.base import HMIIntervention


def fetch_and_convert_static_interventions(policy_intervention_id, job_id):
static_interventions: Dict[torch.Tensor, Dict[str, any]] = defaultdict(dict)
if not (policy_intervention_id):
return static_interventions
return defaultdict(dict)
policy_intervention = fetch_interventions(policy_intervention_id, job_id)
for inter in policy_intervention["interventions"]:
interventionList = policy_intervention["interventions"]
return convert_static_interventions(interventionList)


# Used to convert from HMI Intervention Policy -> pyciemss static interventions.
def convert_static_interventions(interventions: list[HMIIntervention]):
if not (interventions):
return defaultdict(dict)
static_interventions: Dict[torch.Tensor, Dict[str, any]] = defaultdict(dict)
for inter in interventions:
for static_inter in inter["static_interventions"]:
time = torch.tensor(float(static_inter["timestep"]))
parameter_name = inter["applied_to"]
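The tail of convert_static_interventions is collapsed above, so here is a hedged sketch (not the committed code) of the shape the conversion aims for, based on the visible lines and the Dict[torch.Tensor, Dict[str, any]] annotation: each intervention time becomes a tensor key mapping a parameter name to its new value. The final value assignment is an assumption, since that part of the diff is not shown.

# Hedged sketch of the conversion target -- not the committed code.
from collections import defaultdict

import torch

hmi_interventions = [
    {
        "applied_to": "beta",
        "static_interventions": [{"timestep": 2.0, "value": 0.02}],
    }
]

static_interventions = defaultdict(dict)
for inter in hmi_interventions:
    for static_inter in inter["static_interventions"]:
        time = torch.tensor(float(static_inter["timestep"]))
        # Assumed: map time -> {parameter name: new value}.
        static_interventions[time][inter["applied_to"]] = torch.tensor(
            float(static_inter["value"])
        )

print(dict(static_interventions))  # {tensor(2.): {'beta': tensor(0.0200)}}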
59 changes: 34 additions & 25 deletions service/models/operations/optimize.py
@@ -6,14 +6,14 @@
import numpy as np
import torch
from pydantic import BaseModel, Field, Extra
from models.base import OperationRequest, Timespan
from models.base import OperationRequest, Timespan, HMIIntervention
from pyciemss.integration_utils.intervention_builder import (
param_value_objective,
start_time_objective,
)

from pyciemss.ouu.qoi import obs_nday_average_qoi, obs_max_qoi
from models.converters import fetch_and_convert_static_interventions
from models.converters import convert_static_interventions
from utils.tds import fetch_model, fetch_inferred_parameters


@@ -37,22 +37,26 @@ def gen_call(self):
return qoi_map[self.method]


def objfun(x, is_minimized):
if is_minimized:
def objfun(x, initial_guess, objective_function_option):
if objective_function_option == "lower_bound":
return np.sum(np.abs(x))
else:
if objective_function_option == "upper_bound":
return -np.sum(np.abs(x))
if objective_function_option == "initial_guess":
return np.sum(np.abs(x - initial_guess))


class InterventionObjective(BaseModel):
selection: str = Field(
intervention_type: str = Field(
"param_value",
description="The intervention objective to use",
example="param_value",
)
param_names: list[str]
param_values: Optional[list[Optional[float]]] = None
start_time: Optional[list[float]] = None
objective_function_option: Optional[list[str]] = None
initial_guess: Optional[list[float]] = None


class OptimizeExtra(BaseModel):
@@ -70,7 +74,6 @@
)
maxiter: int = 5
maxfeval: int = 25
is_minimized: bool = True
alpha: float = 0.95
solver_method: str = "dopri5"
solver_options: Dict[str, Any] = {}
@@ -80,7 +83,10 @@ class Optimize(OperationRequest):
pyciemss_lib_function: ClassVar[str] = "optimize"
model_config_id: str = Field(..., example="ba8da8d4-047d-11ee-be56")
timespan: Timespan = Timespan(start=0, end=90)
policy_interventions: InterventionObjective
optimize_interventions: InterventionObjective # These are the interventions to be optimized.
fixed_static_parameter_interventions: list[HMIIntervention] = Field(
None
) # These are provided interventions that will not be optimized
step_size: float = 1.0
qoi: QOI
risk_bound: float
@@ -90,35 +96,35 @@
None,
description="optional extra system specific arguments for advanced use cases",
)
policy_intervention_id: str = Field(None, example="ba8da8d4-047d-11ee-be56")

def gen_pyciemss_args(self, job_id):
# Get model from TDS
amr_path = fetch_model(self.model_config_id, job_id)
static_interventions = fetch_and_convert_static_interventions(
self.policy_intervention_id, job_id
fixed_static_parameter_interventions = convert_static_interventions(
self.fixed_static_parameter_interventions
)

intervention_type = self.policy_interventions.selection
intervention_type = self.optimize_interventions.intervention_type
if intervention_type == "param_value":
assert self.policy_interventions.start_time is not None
assert self.optimize_interventions.start_time is not None
start_time = [
torch.tensor(time) for time in self.policy_interventions.start_time
torch.tensor(time) for time in self.optimize_interventions.start_time
]
param_value = [None] * len(self.policy_interventions.param_names)
param_value = [None] * len(self.optimize_interventions.param_names)

policy_interventions = param_value_objective(
optimize_interventions = param_value_objective(
start_time=start_time,
param_name=self.policy_interventions.param_names,
param_name=self.optimize_interventions.param_names,
param_value=param_value,
)
else:
assert self.policy_interventions.param_values is not None
assert self.optimize_interventions.param_values is not None
param_value = [
torch.tensor(value) for value in self.policy_interventions.param_values
torch.tensor(value)
for value in self.optimize_interventions.param_values
]
policy_interventions = start_time_objective(
param_name=self.policy_interventions.param_names,
optimize_interventions = start_time_objective(
param_name=self.optimize_interventions.param_names,
param_value=param_value,
)

@@ -127,20 +133,23 @@ def gen_pyciemss_args(self, job_id):
extra_options.pop("inferred_parameters"), job_id
)
n_samples_ouu = extra_options.pop("num_samples")
is_minimized = extra_options.pop("is_minimized")

return {
"model_path_or_json": amr_path,
"logging_step_size": self.step_size,
"start_time": self.timespan.start,
"end_time": self.timespan.end,
"objfun": lambda x: objfun(x, is_minimized),
"objfun": lambda x: objfun(
x,
self.optimize_interventions.initial_guess[0],
self.optimize_interventions.objective_function_option[0],
),
"qoi": self.qoi.gen_call(),
"risk_bound": self.risk_bound,
"initial_guess_interventions": self.initial_guess_interventions,
"bounds_interventions": self.bounds_interventions,
"static_parameter_interventions": policy_interventions,
"fixed_static_parameter_interventions": static_interventions,
"static_parameter_interventions": optimize_interventions,
"fixed_static_parameter_interventions": fixed_static_parameter_interventions,
"inferred_parameters": inferred_parameters,
"n_samples_ouu": n_samples_ouu,
**extra_options,
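Because the diff interleaves the old and new versions of objfun, the new objective function is restated below as a standalone snippet. The behaviour comments are my reading of the code, not the commit's own documentation: the optimizer minimizes objfun, so "lower_bound" favours small intervention values, "upper_bound" favours large ones, and "initial_guess" favours values near the supplied guess.

# Restatement of the new objfun for readability; comments are interpretation.
import numpy as np

def objfun(x, initial_guess, objective_function_option):
    if objective_function_option == "lower_bound":
        return np.sum(np.abs(x))              # minimized by small |x|
    if objective_function_option == "upper_bound":
        return -np.sum(np.abs(x))             # minimized by large |x|
    if objective_function_option == "initial_guess":
        return np.sum(np.abs(x - initial_guess))  # minimized near the guess

# Example: with option "initial_guess" and a guess of 0.02,
# a candidate of 0.05 scores |0.05 - 0.02| ~= 0.03.
print(objfun(np.array([0.05]), 0.02, "initial_guess"))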
8 changes: 5 additions & 3 deletions tests/examples/optimize/input/request.json
@@ -2,11 +2,13 @@
"engine": "ciemss",
"user_id": "not_provided",
"model_config_id": "sidarthe",
"policy_interventions": {
"selection": "param_value",
"optimize_interventions": {
"intervention_type": "param_value",
"objective_function_option": ["lower_bound"],
"start_time": [2],
"param_names": ["beta"],
"param_values": [0.02]
"param_values": [0.02],
"initial_guess": [0]
},
"timespan": {
"start": 0,
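To connect the renamed request fields with the new fixed interventions, here is an abridged sketch of a request body as a Python dict. It is not taken from the test suite: only fields touched by this commit are shown, and the fixed_static_parameter_interventions entry (including the gamma parameter and its values) is hypothetical, since the test input above does not exercise it.

# Abridged, hypothetical request fragment -- only fields touched by this
# commit; the fixed intervention entry is invented for illustration.
request_fragment = {
    "model_config_id": "sidarthe",
    "optimize_interventions": {
        "intervention_type": "param_value",
        "objective_function_option": ["lower_bound"],
        "start_time": [2],
        "param_names": ["beta"],
        "param_values": [0.02],
        "initial_guess": [0],
    },
    "fixed_static_parameter_interventions": [
        {
            "name": "hold gamma fixed",
            "applied_to": "gamma",
            "type": "static",
            "static_interventions": [{"timestep": 10.0, "value": 0.1}],
        }
    ],
}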
