Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize interface updates #88

Merged
merged 3 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions service/models/operations/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import numpy as np
import torch
from pydantic import BaseModel, Field, Extra
from models.base import OperationRequest, Timespan
from models.base import OperationRequest, Timespan, InterventionObject
from pyciemss.integration_utils.intervention_builder import (
param_value_objective,
start_time_objective,
)

from pyciemss.ouu.qoi import obs_nday_average_qoi, obs_max_qoi

from models.converters import convert_to_static_interventions
from utils.tds import fetch_model, fetch_inferred_parameters


Expand Down Expand Up @@ -80,7 +80,7 @@ class Optimize(OperationRequest):
pyciemss_lib_function: ClassVar[str] = "optimize"
model_config_id: str = Field(..., example="ba8da8d4-047d-11ee-be56")
timespan: Timespan = Timespan(start=0, end=90)
interventions: InterventionObjective
policy_interventions: InterventionObjective
step_size: float = 1.0
qoi: QOI
risk_bound: float
Expand All @@ -90,29 +90,37 @@ class Optimize(OperationRequest):
None,
description="optional extra system specific arguments for advanced use cases",
)
fixed_static_parameter_interventions: List[InterventionObject] = Field(
default_factory=list,
description="The interventions provided via the model config which are not being optimized",
)

def gen_pyciemss_args(self, job_id):
# Get model from TDS
amr_path = fetch_model(self.model_config_id, job_id)

intervention_type = self.interventions.selection
fixed_static_parameter_interventions = convert_to_static_interventions(
self.fixed_static_parameter_interventions
)
intervention_type = self.policy_interventions.selection
if intervention_type == "param_value":
assert self.interventions.start_time is not None
start_time = [torch.tensor(time) for time in self.interventions.start_time]
param_value = [None] * len(self.interventions.param_names)
assert self.policy_interventions.start_time is not None
start_time = [
torch.tensor(time) for time in self.policy_interventions.start_time
]
param_value = [None] * len(self.policy_interventions.param_names)

interventions = param_value_objective(
policy_interventions = param_value_objective(
start_time=start_time,
param_name=self.interventions.param_names,
param_name=self.policy_interventions.param_names,
param_value=param_value,
)
else:
assert self.interventions.param_values is not None
assert self.policy_interventions.param_values is not None
param_value = [
torch.tensor(value) for value in self.interventions.param_values
torch.tensor(value) for value in self.policy_interventions.param_values
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might want to check the torch.tensors have correct dtypes needed in pyciemss?

]
interventions = start_time_objective(
param_name=self.interventions.param_names,
policy_interventions = start_time_objective(
param_name=self.policy_interventions.param_names,
param_value=param_value,
)

Expand All @@ -133,7 +141,8 @@ def gen_pyciemss_args(self, job_id):
"risk_bound": self.risk_bound,
"initial_guess_interventions": self.initial_guess_interventions,
"bounds_interventions": self.bounds_interventions,
"static_parameter_interventions": interventions,
"static_parameter_interventions": policy_interventions,
"fixed_static_parameter_interventions": fixed_static_parameter_interventions,
"inferred_parameters": inferred_parameters,
"n_samples_ouu": n_samples_ouu,
**extra_options,
Expand Down
2 changes: 1 addition & 1 deletion tests/examples/optimize/input/request.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"engine": "ciemss",
"user_id": "not_provided",
"model_config_id": "sidarthe",
"interventions": {
"policy_interventions": {
"selection": "param_value",
"start_time": [2],
"param_names": ["beta"],
Expand Down
Loading