From 52e36c7f07aa4ed2713782b9df799a72c785f94c Mon Sep 17 00:00:00 2001
From: sabinala
Date: Tue, 30 Jul 2024 10:23:52 -0700
Subject: [PATCH 1/3] making failing test for intervention on constant param

---
 tests/fixtures.py        |  1 +
 tests/test_interfaces.py | 37 +++++++++++++++++++++++++++++++++++++
 2 files changed, 38 insertions(+)

diff --git a/tests/fixtures.py b/tests/fixtures.py
index 9d2d75802..24354642c 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -182,6 +182,7 @@ def __init__(
 
 LOGGING_STEP_SIZES = [5.0]
 NUM_SAMPLES = [2]
+SEIRHD_NPI_STATIC_PARAM_INTERV = [{torch.tensor(10.0): {"delta": torch.tensor(0.2)}}]
 NON_POS_INTS = [
     3.5,
     -3,
diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py
index 4b58a5710..ed06719b7 100644
--- a/tests/test_interfaces.py
+++ b/tests/test_interfaces.py
@@ -30,6 +30,7 @@
     NON_POS_INTS,
     NUM_SAMPLES,
     OPT_MODELS,
+    SEIRHD_NPI_STATIC_PARAM_INTERV,
     START_TIMES,
     check_result_sizes,
     check_states_match,
@@ -731,3 +732,39 @@ def test_errors_for_bad_amrs(
         logging_step_size,
         num_samples,
     )
+
+
+@pytest.mark.parametrize("sample_method", [sample])
+@pytest.mark.parametrize("model_fixture", MODELS)
+@pytest.mark.parametrize("end_time", END_TIMES)
+@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
+@pytest.mark.parametrize("num_samples", NUM_SAMPLES)
+@pytest.mark.parametrize("seirhd_npi_intervention", SEIRHD_NPI_STATIC_PARAM_INTERV)
+def test_intervention_on_constant_param(
+    sample_method,
+    model_fixture,
+    end_time,
+    logging_step_size,
+    num_samples,
+    seirhd_npi_intervention,
+):
+    # Assert that sample returns expected result with intervention on constant parameter
+    if "SEIRHD_NPI" not in model_fixture.url:
+        print("skipped")
+        print(model_fixture.url)
+        pytest.skip("Only test 'SEIRHD_NPI' models with constant parameter delta")
+    else:
+        print("unskipped")
+        print(model_fixture.url)
+        processed_result = sample_method(
+            model_fixture.url,
+            end_time,
+            logging_step_size,
+            num_samples,
+            static_parameter_interventions=seirhd_npi_intervention,
+        )["data"]
+        assert isinstance(processed_result, pd.DataFrame)
+        assert processed_result.shape[0] == num_samples * len(
+            torch.arange(start_time, end_time + logging_step_size, logging_step_size)
+        )
+        assert processed_result.shape[1] >= 2

From 5b53a4dcfd311676e8363340c685171bc59eb2ec Mon Sep 17 00:00:00 2001
From: sabinala
Date: Tue, 30 Jul 2024 10:41:13 -0700
Subject: [PATCH 2/3] linting

---
 tests/test_interfaces.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py
index ed06719b7..91b62720f 100644
--- a/tests/test_interfaces.py
+++ b/tests/test_interfaces.py
@@ -739,6 +739,7 @@ def test_errors_for_bad_amrs(
 @pytest.mark.parametrize("end_time", END_TIMES)
 @pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
 @pytest.mark.parametrize("num_samples", NUM_SAMPLES)
+@pytest.mark.parametrize("start_time", START_TIMES)
 @pytest.mark.parametrize("seirhd_npi_intervention", SEIRHD_NPI_STATIC_PARAM_INTERV)
 def test_intervention_on_constant_param(
     sample_method,
@@ -746,21 +747,19 @@ def test_intervention_on_constant_param(
     end_time,
     logging_step_size,
     num_samples,
+    start_time,
     seirhd_npi_intervention,
 ):
     # Assert that sample returns expected result with intervention on constant parameter
     if "SEIRHD_NPI" not in model_fixture.url:
-        print("skipped")
-        print(model_fixture.url)
         pytest.skip("Only test 'SEIRHD_NPI' models with constant parameter delta")
     else:
-        print("unskipped")
-        print(model_fixture.url)
         processed_result = sample_method(
             model_fixture.url,
             end_time,
             logging_step_size,
             num_samples,
+            start_time=start_time,
             static_parameter_interventions=seirhd_npi_intervention,
         )["data"]
         assert isinstance(processed_result, pd.DataFrame)

From d06cd99e109d8e45284c77131bf7727d8cc4ac22 Mon Sep 17 00:00:00 2001
From: Sam Witty
Date: Thu, 8 Aug 2024 10:23:00 -0400
Subject: [PATCH 3/3] add deterministic parameters to the trace

---
 pyciemss/compiled_dynamics.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/pyciemss/compiled_dynamics.py b/pyciemss/compiled_dynamics.py
index 2988c5feb..85248440b 100644
--- a/pyciemss/compiled_dynamics.py
+++ b/pyciemss/compiled_dynamics.py
@@ -101,13 +101,15 @@ def observables(self, X: State[torch.Tensor]) -> State[torch.Tensor]:
     def instantiate_parameters(self):
         # Initialize random parameters once before simulating.
         # This is necessary because the parameters are PyroSample objects.
-        for k in _compile_param_values(self.src).keys():
+        for k, param in _compile_param_values(self.src).items():
             param_name = get_name(k)
             # Separating the persistent parameters from the non-persistent ones
             # is necessary because the persistent parameters are PyroSample objects representing the distribution,
             # and should not be modified during intervention.
             param_val = getattr(self, f"persistent_{param_name}")
-            self.register_buffer(get_name(k), param_val)
+            if isinstance(param, torch.Tensor):
+                pyro.deterministic(f"persistent_{param_name}", param_val)
+            self.register_buffer(param_name, param_val)
 
     def forward(
         self,
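
A minimal usage sketch (not part of the patches above) of the behavior these changes exercise: intervening on a constant model parameter via `static_parameter_interventions` and reading back the sampled trajectories. The model URL, start/end times, and the import path are assumptions for illustration; only the intervention dictionary shape, the argument order and keywords, and the `["data"]` result key are taken from the diffs themselves.

    import torch

    # Import path assumed; the tests above only show `sample` already in scope.
    from pyciemss.interfaces import sample

    # Hypothetical SEIRHD_NPI model configuration URL; not a value from this diff.
    MODEL_URL = "https://example.org/SEIRHD_NPI_model.json"

    # At simulated time 10.0, set the constant parameter "delta" to 0.2,
    # mirroring SEIRHD_NPI_STATIC_PARAM_INTERV added in tests/fixtures.py.
    intervention = {torch.tensor(10.0): {"delta": torch.tensor(0.2)}}

    result = sample(
        MODEL_URL,
        100.0,  # end_time (assumed horizon)
        5.0,    # logging_step_size, as in LOGGING_STEP_SIZES
        2,      # num_samples, as in NUM_SAMPLES
        start_time=0.0,  # assumed start time
        static_parameter_interventions=intervention,
    )["data"]  # per the tests, "data" holds a pandas DataFrame

    # With the third patch applied, constant (torch.Tensor) parameters are also
    # recorded in the Pyro trace as deterministic sites such as "persistent_delta".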