Allow for interventions on constant parameters (#597)
* add failing test for intervention on constant param

* linting

* add deterministic parameters to the trace

---------

Co-authored-by: sabinala <sabina.altus@pnnl.gov>
SamWitty and sabinala authored Aug 8, 2024
1 parent adeb6b9 commit f5c5bec
Showing 3 changed files with 41 additions and 2 deletions.
6 changes: 4 additions & 2 deletions pyciemss/compiled_dynamics.py
```diff
@@ -101,13 +101,15 @@ def observables(self, X: State[torch.Tensor]) -> State[torch.Tensor]:
     def instantiate_parameters(self):
         # Initialize random parameters once before simulating.
         # This is necessary because the parameters are PyroSample objects.
-        for k in _compile_param_values(self.src).keys():
+        for k, param in _compile_param_values(self.src).items():
             param_name = get_name(k)
             # Separating the persistent parameters from the non-persistent ones
             # is necessary because the persistent parameters are PyroSample objects representing the distribution,
             # and should not be modified during intervention.
             param_val = getattr(self, f"persistent_{param_name}")
-            self.register_buffer(get_name(k), param_val)
+            if isinstance(param, torch.Tensor):
+                pyro.deterministic(f"persistent_{param_name}", param_val)
+            self.register_buffer(param_name, param_val)
 
     def forward(
         self,
```
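The key line is the `pyro.deterministic` call: constant (non-random) parameters are now recorded as named sites in the Pyro trace, which is what allows intervention handlers to target them just like sampled parameters. A minimal sketch of that mechanism outside pyciemss (model and site names here are illustrative, not from the diff):

```python
import pyro
import torch

def model():
    # A constant parameter, recorded as a deterministic site so it appears
    # in the trace alongside sampled parameters.
    delta = torch.tensor(0.2)
    pyro.deterministic("persistent_delta", delta)

# Tracing the model exposes the site, so effect handlers can observe
# or replace its value.
trace = pyro.poutine.trace(model).get_trace()
print(trace.nodes["persistent_delta"]["value"])  # tensor(0.2000)
```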
1 change: 1 addition & 0 deletions tests/fixtures.py
```diff
@@ -193,6 +193,7 @@ def __init__(
 LOGGING_STEP_SIZES = [5.0]
 
 NUM_SAMPLES = [2]
+SEIRHD_NPI_STATIC_PARAM_INTERV = [{torch.tensor(10.0): {"delta": torch.tensor(0.2)}}]
 NON_POS_INTS = [
     3.5,
     -3,
```
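The new fixture encodes a static parameter intervention as `{intervention_time: {parameter_name: new_value}}`: at `t = 10.0`, the constant parameter `delta` is set to `0.2`. A schedule with several change points follows the same shape; a hypothetical two-stage example (times and values are illustrative, not from the fixture above):

```python
import torch

# {intervention_time: {parameter_name: new_value}}
two_stage_intervention = {
    torch.tensor(10.0): {"delta": torch.tensor(0.2)},
    torch.tensor(20.0): {"delta": torch.tensor(0.05)},
}
```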
36 changes: 36 additions & 0 deletions tests/test_interfaces.py
```diff
@@ -30,6 +30,7 @@
     NON_POS_INTS,
     NUM_SAMPLES,
     OPT_MODELS,
+    SEIRHD_NPI_STATIC_PARAM_INTERV,
     START_TIMES,
     check_result_sizes,
     check_states_match,
@@ -737,3 +738,38 @@ def test_errors_for_bad_amrs(
         logging_step_size,
         num_samples,
     )
+
+
+@pytest.mark.parametrize("sample_method", [sample])
+@pytest.mark.parametrize("model_fixture", MODELS)
+@pytest.mark.parametrize("end_time", END_TIMES)
+@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
+@pytest.mark.parametrize("num_samples", NUM_SAMPLES)
+@pytest.mark.parametrize("start_time", START_TIMES)
+@pytest.mark.parametrize("seirhd_npi_intervention", SEIRHD_NPI_STATIC_PARAM_INTERV)
+def test_intervention_on_constant_param(
+    sample_method,
+    model_fixture,
+    end_time,
+    logging_step_size,
+    num_samples,
+    start_time,
+    seirhd_npi_intervention,
+):
+    # Assert that sample returns expected result with intervention on constant parameter
+    if "SEIRHD_NPI" not in model_fixture.url:
+        pytest.skip("Only test 'SEIRHD_NPI' models with constant parameter delta")
+    else:
+        processed_result = sample_method(
+            model_fixture.url,
+            end_time,
+            logging_step_size,
+            num_samples,
+            start_time=start_time,
+            static_parameter_interventions=seirhd_npi_intervention,
+        )["data"]
+        assert isinstance(processed_result, pd.DataFrame)
+        assert processed_result.shape[0] == num_samples * len(
+            torch.arange(start_time, end_time + logging_step_size, logging_step_size)
+        )
+        assert processed_result.shape[1] >= 2
```
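For reference, a standalone call mirroring what this test exercises. The model URL is a placeholder (any SEIRHD_NPI AMR model with a constant `delta` parameter), and the time/value settings are illustrative:

```python
import torch
from pyciemss.interfaces import sample

# Placeholder URL -- an assumption, not taken from the diff above.
MODEL_URL = "https://example.org/SEIRHD_NPI_model.json"

# At t = 10.0, override the constant parameter delta with 0.2.
intervention = {torch.tensor(10.0): {"delta": torch.tensor(0.2)}}

result = sample(
    MODEL_URL,
    end_time=100.0,
    logging_step_size=5.0,
    num_samples=2,
    start_time=0.0,
    static_parameter_interventions=intervention,
)["data"]  # a pandas DataFrame of sampled trajectories

print(result.shape)
```

Before this commit, the intervention above would fail for `delta` because constant parameters were never registered as sites in the trace; with the `pyro.deterministic` change they can be intervened on like any sampled parameter.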
