Skip to content

Commit

Permalink
make distributional parameters optional
Browse files Browse the repository at this point in the history
  • Loading branch information
SamWitty committed Aug 28, 2024
1 parent 040ac12 commit 112f93c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
4 changes: 3 additions & 1 deletion tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ def __init__(
data_mapping: Dict[str, str] = {},
data_mapped_to_observable: bool = False,
optimize_kwargs: Dict[str, Any] = None,
has_distributional_parameters: bool = True,
):
self.url = url
self.important_parameter = important_parameter
self.data_path = data_path
self.data_mapping = data_mapping
self.data_mapped_to_observable = data_mapped_to_observable
self.optimize_kwargs = optimize_kwargs
self.has_distributional_parameters = has_distributional_parameters


# See https://github.com/DARPA-ASKEM/Model-Representations/issues/62 for discussion of valid models.
Expand Down Expand Up @@ -85,7 +87,7 @@ def __init__(
ModelFixture(
os.path.join(MODELS_PATH, "LV_rabbits_wolves_model03_regnet.json"), "beta"
),
ModelFixture(os.path.join(MODELS_PATH, "LacOperon.json")),
ModelFixture(os.path.join(MODELS_PATH, "LacOperon.json"), "k_1", has_distributional_parameters=False),
]

STOCKFLOW_MODELS = [
Expand Down
9 changes: 6 additions & 3 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,16 @@ def setup_calibrate(model_fixture, start_time, end_time, logging_step_size):


@pytest.mark.parametrize("sample_method", SAMPLE_METHODS)
@pytest.mark.parametrize("model_url", MODEL_URLS)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("start_time", START_TIMES)
@pytest.mark.parametrize("end_time", END_TIMES)
@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
@pytest.mark.parametrize("num_samples", NUM_SAMPLES)
def test_sample_no_interventions(
sample_method, model_url, start_time, end_time, logging_step_size, num_samples
sample_method, model, start_time, end_time, logging_step_size, num_samples
):
model_url = model.url

with pyro.poutine.seed(rng_seed=0):
result1 = sample_method(
model_url, end_time, logging_step_size, num_samples, start_time=start_time
Expand All @@ -115,7 +117,8 @@ def test_sample_no_interventions(
check_result_sizes(result, start_time, end_time, logging_step_size, num_samples)

check_states_match(result1, result2)
check_states_match_in_all_but_values(result1, result3)
if model.has_distributional_parameters:
check_states_match_in_all_but_values(result1, result3)

if sample_method.__name__ == "dummy_ensemble_sample":
assert "total_state" in result1.keys()
Expand Down

0 comments on commit 112f93c

Please sign in to comment.