Skip to content

Commit

Permalink
adding template for start time and parameter value intervention for `optimize` (#585)
Browse files Browse the repository at this point in the history

* adding template for start time and parameter value intervention

* Update interfaces.ipynb

* Lint

* Adding test

* lint

* fixing tests

* lint
  • Loading branch information
anirban-chaudhuri authored Jul 3, 2024
1 parent 422232c commit 672201f
Show file tree
Hide file tree
Showing 3 changed files with 436 additions and 66 deletions.
433 changes: 372 additions & 61 deletions docs/source/interfaces.ipynb

Large diffs are not rendered by default.

53 changes: 48 additions & 5 deletions pyciemss/integration_utils/intervention_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ def param_value_objective(
start_time: List[torch.Tensor],
param_value: List[Intervention] = [None],
) -> Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]:
if len(param_value) < len(param_name) and param_value[0] is None:
param_size = len(param_name)
if len(param_value) < param_size and param_value[0] is None:
param_value = [None for _ in param_name]
for count in range(len(param_name)):
for count in range(param_size):
if param_value[count] is None:
if not callable(param_value[count]):
param_value[count] = lambda y: torch.tensor(y)
Expand All @@ -20,7 +21,7 @@ def intervention_generator(
x: torch.Tensor,
) -> Dict[float, Dict[str, Intervention]]:
static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}
for count in range(len(param_name)):
for count in range(param_size):
if start_time[count].item() in static_parameter_interventions:
static_parameter_interventions[start_time[count].item()].update(
{param_name[count]: param_value[count](x[count].item())}
Expand All @@ -39,13 +40,16 @@ def intervention_generator(


def start_time_objective(
param_name: List[str], param_value: List[Intervention]
param_name: List[str],
param_value: List[Intervention],
) -> Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]:
param_size = len(param_name)

def intervention_generator(
x: torch.Tensor,
) -> Dict[float, Dict[str, Intervention]]:
static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}
for count in range(len(param_name)):
for count in range(param_size):
if x[count].item() in static_parameter_interventions:
static_parameter_interventions[x[count].item()].update(
{param_name[count]: param_value[count]}
Expand All @@ -59,6 +63,45 @@ def intervention_generator(
return intervention_generator


def start_time_param_value_objective(
    param_name: List[str],
    param_value: "List[Intervention]" = [None],  # noqa: B006 -- copied below, never mutated
) -> "Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]":
    """Build an intervention generator that optimizes both the start time and
    the value of each parameter intervention.

    The returned callable maps a flat decision vector ``x`` of length
    ``2 * len(param_name)`` -- laid out as ``[time_0, value_0, time_1,
    value_1, ...]`` -- to the static-parameter-interventions dict
    ``{start_time: {param_name: intervention_value}}``.

    Args:
        param_name: Names of the model parameters to intervene on.
        param_value: Optional per-parameter conversion callables applied to
            the raw optimized value. Entries that are ``None`` (the default)
            fall back to wrapping the raw value in a ``torch.Tensor``.

    Returns:
        A function from a decision tensor to the interventions dict.
    """
    param_size = len(param_name)
    # Normalize into a fresh local list so that neither the caller's list nor
    # the shared mutable default argument is ever mutated (the previous
    # implementation mutated both in place).
    if len(param_value) < param_size and param_value[0] is None:
        values = [None] * param_size
    else:
        values = list(param_value)
    # A None entry means "no conversion": wrap the raw optimized value in a
    # tensor. (The old nested `callable` check was dead code -- None is never
    # callable, so the condition was always true when reached.)
    values = [(lambda y: torch.tensor(y)) if v is None else v for v in values]

    def intervention_generator(
        x: torch.Tensor,
    ) -> "Dict[float, Dict[str, Intervention]]":
        assert (
            x.size()[0] == param_size * 2
        ), "Size mismatch: check size for initial_guess_interventions and/or bounds_interventions"
        static_parameter_interventions: "Dict[float, Dict[str, Intervention]]" = {}
        for count in range(param_size):
            # Even slots hold start times, odd slots hold the raw values.
            start = x[count * 2].item()
            value = values[count](x[count * 2 + 1].item())
            # Interventions that share a start time are merged into one dict.
            static_parameter_interventions.setdefault(start, {})[
                param_name[count]
            ] = value
        return static_parameter_interventions

    return intervention_generator


def combine_static_parameter_interventions(
interventions: List[Dict[torch.Tensor, Dict[str, Intervention]]]
) -> Dict[torch.Tensor, Dict[str, Intervention]]:
Expand Down
16 changes: 16 additions & 0 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pyciemss.integration_utils.intervention_builder import (
param_value_objective,
start_time_objective,
start_time_param_value_objective,
)
from pyciemss.ouu.qoi import obs_max_qoi, obs_nday_average_qoi

Expand Down Expand Up @@ -117,6 +118,17 @@ def __init__(
"bounds_interventions": [[0.0], [40.0]],
}

# Optimize fixture exercising `start_time_param_value_objective`: jointly
# optimizes the start time (x[0]) and value (x[1]) of a single intervention
# on the "p_cbeta" parameter of the SIR stockflow model.
optimize_kwargs_SIRstockflow_time_param = {
    # Quantity of interest: 1-day average of the "I_state" trajectory
    # (presumably the infected compartment -- confirm against the model).
    "qoi": lambda x: obs_nday_average_qoi(x, ["I_state"], 1),
    # Upper bound on the QoI risk used as the optimization constraint.
    "risk_bound": 300.0,
    "static_parameter_interventions": start_time_param_value_objective(
        param_name=["p_cbeta"],
    ),
    # Objective over the decision vector [start_time, param_value]:
    # rewards later start times (x[0] is scaled by 0.25/(0.0-40.0), i.e. the
    # bounds span below) and penalizes deviation of x[1] from 0.35.
    "objfun": lambda x: -x[0] * 0.25 / (0.0 - 40.0) + np.abs(0.35 - x[1]) * 1.0,
    # Initial guess: [start_time, param_value].
    "initial_guess_interventions": [1.0, 0.15],
    # Bounds as [lower, upper] pairs: time in [0, 40], value in [0.1, 0.5].
    "bounds_interventions": [[0.0, 0.1], [40.0, 0.5]],
}

optimize_kwargs_SEIRHD_param_maxQoI = {
"qoi": lambda x: obs_max_qoi(x, ["I_state"]),
"risk_bound": 300.0,
Expand All @@ -139,6 +151,10 @@ def __init__(
os.path.join(MODELS_PATH, "SIR_stockflow.json"),
optimize_kwargs=optimize_kwargs_SIRstockflow_time,
),
ModelFixture(
os.path.join(MODELS_PATH, "SIR_stockflow.json"),
optimize_kwargs=optimize_kwargs_SIRstockflow_time_param,
),
ModelFixture(
os.path.join(MODELS_PATH, "SEIRHD_NPI_Type1_petrinet.json"),
optimize_kwargs=optimize_kwargs_SEIRHD_param_maxQoI,
Expand Down

0 comments on commit 672201f

Please sign in to comment.