diff --git a/pyciemss/interfaces.py b/pyciemss/interfaces.py index 120c44370..e68be4410 100644 --- a/pyciemss/interfaces.py +++ b/pyciemss/interfaces.py @@ -1,5 +1,5 @@ import contextlib -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union import pyro import torch @@ -331,7 +331,7 @@ def calibrate( num_particles: int = 1, deterministic_learnable_parameters: List[str] = [], progress_hook: Callable = lambda i, loss: None, -) -> Tuple[pyro.nn.PyroModule, float]: +) -> Dict[str, Any]: """ Infer parameters for a DynamicalSystem model conditional on data. This uses variational inference with a mean-field variational family to infer the parameters of the model. @@ -412,11 +412,13 @@ def calibrate( - This can be used to implement custom progress bars. Returns: - - inferred_parameters: pyro.nn.PyroModule - - A Pyro module that contains the inferred parameters of the model. - - This can be passed to `sample` to sample from the model conditional on the data. - - loss: float - - The final loss value of the approximate ELBO loss. + result: Dict[str, Any] + - Dictionary with the following key-value pairs. + - inferred_parameters: pyro.nn.PyroModule + - A Pyro module that contains the inferred parameters of the model. + - This can be passed to `sample` to sample from the model conditional on the data. + - loss: float + - The final loss value of the approximate ELBO loss. """ pyro.clear_param_store() @@ -511,7 +513,7 @@ def wrapped_model(): if i % 25 == 0: print(f"iteration {i}: loss = {loss}") - return inferred_parameters, loss + return {"inferred_parameters": inferred_parameters, "loss": loss} # # TODO diff --git a/setup.cfg b/setup.cfg index cb6432b71..4ad546b3e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = pyciemss -version = 0.1.0 +version = 0.1.1 license = BSD-3-Clause license_files = LICENSE @@ -24,6 +24,7 @@ packages = pyciemss pyciemss.integration_utils pyciemss.mira_integration + pyciemss.ensemble [options.package_data] * = *.json diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index 58a0913d3..8af70d3e4 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -210,7 +210,9 @@ def test_calibrate_no_kwargs(model_fixture, start_time, end_time, logging_step_s } with pyro.poutine.seed(rng_seed=0): - inferred_parameters, _ = calibrate(*calibrate_args, **calibrate_kwargs) + inferred_parameters = calibrate(*calibrate_args, **calibrate_kwargs)[ + "inferred_parameters" + ] assert isinstance(inferred_parameters, pyro.nn.PyroModule) @@ -255,7 +257,8 @@ def test_calibrate_deterministic( } with pyro.poutine.seed(rng_seed=0): - inferred_parameters, _ = calibrate(*calibrate_args, **calibrate_kwargs) + output = calibrate(*calibrate_args, **calibrate_kwargs) + inferred_parameters = output["inferred_parameters"] assert isinstance(inferred_parameters, pyro.nn.PyroModule) @@ -307,7 +310,7 @@ def test_calibrate_interventions( } with pyro.poutine.seed(rng_seed=0): - _, loss = calibrate(*calibrate_args, **calibrate_kwargs) + loss = calibrate(*calibrate_args, **calibrate_kwargs)["loss"] # SETUP INTERVENTION @@ -334,8 +337,11 @@ def time_key(time, _): } with pyro.poutine.seed(rng_seed=0): - intervened_parameters, intervened_loss = calibrate( - *calibrate_args, **calibrate_kwargs + output = calibrate(*calibrate_args, **calibrate_kwargs) + + intervened_parameters, intervened_loss = ( + output["inferred_parameters"], + output["loss"], ) assert intervened_loss != loss