Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize new plumbing #101

Merged
merged 6 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion service/models/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import ClassVar, Dict
from typing import ClassVar, Dict, Optional
from pydantic import BaseModel, Field


Expand Down Expand Up @@ -42,6 +42,31 @@ class InterventionSelection(BaseModel):
name: str


class HMIStaticIntervention(BaseModel):
timestep: float
value: float


class HMIDynamicIntervention(BaseModel):
parameter: str
threshold: float
value: float
is_greater_than: bool


class HMIInterventionPolicy(BaseModel):
model_id: Optional[str] = Field(default="")
Interventions: list[HMIIntervention]


class HMIIntervention(BaseModel):
name: str
applied_to: str
type: str
static_interventions: Optional[list[HMIStaticIntervention]] = Field(default=None)
dynamic_interventions: Optional[list[HMIDynamicIntervention]] = Field(default=None)


class OperationRequest(BaseModel):
pyciemss_lib_function: ClassVar[str] = ""
engine: str = Field("ciemss", example="ciemss")
Expand Down
15 changes: 12 additions & 3 deletions service/models/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,23 @@
import torch
from utils.tds import fetch_interventions
from typing import Dict
from models.base import HMIIntervention


def fetch_and_convert_static_interventions(policy_intervention_id, job_id):
static_interventions: Dict[torch.Tensor, Dict[str, any]] = defaultdict(dict)
if not (policy_intervention_id):
return static_interventions
return defaultdict(dict)
policy_intervention = fetch_interventions(policy_intervention_id, job_id)
for inter in policy_intervention["interventions"]:
interventionList = policy_intervention["interventions"]
return convert_static_interventions(interventionList)


# Used to convert from HMI Intervention Policy -> pyciemss static interventions.
def convert_static_interventions(interventions: list[HMIIntervention]):
if not (interventions):
return defaultdict(dict)
static_interventions: Dict[torch.Tensor, Dict[str, any]] = defaultdict(dict)
for inter in interventions:
for static_inter in inter["static_interventions"]:
time = torch.tensor(float(static_inter["timestep"]))
parameter_name = inter["applied_to"]
Expand Down
59 changes: 34 additions & 25 deletions service/models/operations/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import numpy as np
import torch
from pydantic import BaseModel, Field, Extra
from models.base import OperationRequest, Timespan
from models.base import OperationRequest, Timespan, HMIIntervention
from pyciemss.integration_utils.intervention_builder import (
param_value_objective,
start_time_objective,
)

from pyciemss.ouu.qoi import obs_nday_average_qoi, obs_max_qoi
from models.converters import fetch_and_convert_static_interventions
from models.converters import convert_static_interventions
from utils.tds import fetch_model, fetch_inferred_parameters


Expand All @@ -37,22 +37,26 @@ def gen_call(self):
return qoi_map[self.method]


def objfun(x, is_minimized):
if is_minimized:
def objfun(x, initial_guess, objective_function_option):
if objective_function_option == "lower_bound":
return np.sum(np.abs(x))
else:
if objective_function_option == "upper_bound":
return -np.sum(np.abs(x))
if objective_function_option == "initial_guess":
return np.sum(np.abs(x - initial_guess))


class InterventionObjective(BaseModel):
selection: str = Field(
intervention_type: str = Field(
"param_value",
description="The intervention objective to use",
example="param_value",
)
param_names: list[str]
param_values: Optional[list[Optional[float]]] = None
start_time: Optional[list[float]] = None
objective_function_option: Optional[list[str]] = None
initial_guess: Optional[list[float]] = None


class OptimizeExtra(BaseModel):
Expand All @@ -70,7 +74,6 @@ class OptimizeExtra(BaseModel):
)
maxiter: int = 5
maxfeval: int = 25
is_minimized: bool = True
alpha: float = 0.95
solver_method: str = "dopri5"
solver_options: Dict[str, Any] = {}
Expand All @@ -80,7 +83,10 @@ class Optimize(OperationRequest):
pyciemss_lib_function: ClassVar[str] = "optimize"
model_config_id: str = Field(..., example="ba8da8d4-047d-11ee-be56")
timespan: Timespan = Timespan(start=0, end=90)
policy_interventions: InterventionObjective
optimize_interventions: InterventionObjective # These are the interventions to be optimized.
fixed_static_parameter_interventions: list[HMIIntervention] = Field(
None
) # Theses are interventions provided that will not be optimized
step_size: float = 1.0
qoi: QOI
risk_bound: float
Expand All @@ -90,35 +96,35 @@ class Optimize(OperationRequest):
None,
description="optional extra system specific arguments for advanced use cases",
)
policy_intervention_id: str = Field(None, example="ba8da8d4-047d-11ee-be56")

def gen_pyciemss_args(self, job_id):
# Get model from TDS
amr_path = fetch_model(self.model_config_id, job_id)
static_interventions = fetch_and_convert_static_interventions(
self.policy_intervention_id, job_id
fixed_static_parameter_interventions = convert_static_interventions(
self.fixed_static_parameter_interventions
)

intervention_type = self.policy_interventions.selection
intervention_type = self.optimize_interventions.intervention_type
if intervention_type == "param_value":
assert self.policy_interventions.start_time is not None
assert self.optimize_interventions.start_time is not None
start_time = [
torch.tensor(time) for time in self.policy_interventions.start_time
torch.tensor(time) for time in self.optimize_interventions.start_time
]
param_value = [None] * len(self.policy_interventions.param_names)
param_value = [None] * len(self.optimize_interventions.param_names)

policy_interventions = param_value_objective(
optimize_interventions = param_value_objective(
start_time=start_time,
param_name=self.policy_interventions.param_names,
param_name=self.optimize_interventions.param_names,
param_value=param_value,
)
else:
assert self.policy_interventions.param_values is not None
assert self.optimize_interventions.param_values is not None
param_value = [
torch.tensor(value) for value in self.policy_interventions.param_values
torch.tensor(value)
for value in self.optimize_interventions.param_values
]
policy_interventions = start_time_objective(
param_name=self.policy_interventions.param_names,
optimize_interventions = start_time_objective(
param_name=self.optimize_interventions.param_names,
param_value=param_value,
)

Expand All @@ -127,20 +133,23 @@ def gen_pyciemss_args(self, job_id):
extra_options.pop("inferred_parameters"), job_id
)
n_samples_ouu = extra_options.pop("num_samples")
is_minimized = extra_options.pop("is_minimized")

return {
"model_path_or_json": amr_path,
"logging_step_size": self.step_size,
"start_time": self.timespan.start,
"end_time": self.timespan.end,
"objfun": lambda x: objfun(x, is_minimized),
"objfun": lambda x: objfun(
x,
self.optimize_interventions.initial_guess[0],
self.optimize_interventions.objective_function_option[0],
),
"qoi": self.qoi.gen_call(),
"risk_bound": self.risk_bound,
"initial_guess_interventions": self.initial_guess_interventions,
"bounds_interventions": self.bounds_interventions,
"static_parameter_interventions": policy_interventions,
"fixed_static_parameter_interventions": static_interventions,
"static_parameter_interventions": optimize_interventions,
"fixed_static_parameter_interventions": fixed_static_parameter_interventions,
"inferred_parameters": inferred_parameters,
"n_samples_ouu": n_samples_ouu,
**extra_options,
Expand Down
8 changes: 5 additions & 3 deletions tests/examples/optimize/input/request.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
"engine": "ciemss",
"user_id": "not_provided",
"model_config_id": "sidarthe",
"policy_interventions": {
"selection": "param_value",
"optimize_interventions": {
"intervention_type": "param_value",
"objective_function_option": ["lower_bound"],
"start_time": [2],
"param_names": ["beta"],
"param_values": [0.02]
"param_values": [0.02],
"initial_guess": [0]
},
"timespan": {
"start": 0,
Expand Down
Loading