From 266923371b25e1cf0e1e95c825062c495db4e6f3 Mon Sep 17 00:00:00 2001 From: Five Grant <5@fivegrant.com> Date: Wed, 3 Jan 2024 09:12:34 -0600 Subject: [PATCH] Use refactor fixes early (#1) * Add `ensemble` dir to `setup.cfg` * Align return type of `calibrate` * Change "parameters" -> "inferred_parameters" * added pyciemss logging wrapper to calibrate (#444) * Add progress hook (and example usage in test) for `calibrate` (#445) * added progress hook for calibrate * added test with simple logging * lint * Fix formatting, linting, etc --------- Co-authored-by: Sam Witty --- pyciemss/interfaces.py | 27 ++++++++++++------ setup.cfg | 1 + tests/test_interfaces.py | 59 ++++++++++++++++++++++++++++++++++++---- 3 files changed, 74 insertions(+), 13 deletions(-) diff --git a/pyciemss/interfaces.py b/pyciemss/interfaces.py index fffc56346..7c548a2d7 100644 --- a/pyciemss/interfaces.py +++ b/pyciemss/interfaces.py @@ -1,5 +1,5 @@ import contextlib -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union import pyro import torch @@ -299,6 +299,7 @@ def wrapped_model(): return prepare_interchange_dictionary(samples) +@pyciemss_logging_wrapper def calibrate( model_path_or_json: Union[str, Dict], data_path: str, @@ -324,7 +325,8 @@ def calibrate( verbose: bool = False, num_particles: int = 1, deterministic_learnable_parameters: List[str] = [], -) -> Tuple[pyro.nn.PyroModule, float]: + progress_hook: Callable = lambda i, loss: None, +) -> Dict[str, Any]: """ Infer parameters for a DynamicalSystem model conditional on data. This uses variational inference with a mean-field variational family to infer the parameters of the model. @@ -398,13 +400,20 @@ def calibrate( - deterministic_learnable_parameters: List[str] - A list of parameter names that should be learned deterministically. - By default, all parameters are learned probabilistically. + - progress_hook: Callable[[int, float], None] + - A function that takes in the current iteration and the current loss. + - This is called at the beginning of each iteration. + - By default, this is a no-op. + - This can be used to implement custom progress bars. Returns: - - inferred_parameters: pyro.nn.PyroModule - - A Pyro module that contains the inferred parameters of the model. - - This can be passed to `sample` to sample from the model conditional on the data. - - loss: float - - The final loss value of the approximate ELBO loss. + result: Dict[str, Any] + - Dictionary with the following key-value pairs. + - inferred_parameters: pyro.nn.PyroModule + - A Pyro module that contains the inferred parameters of the model. + - This can be passed to `sample` to sample from the model conditional on the data. + - loss: float + - The final loss value of the approximate ELBO loss. """ pyro.clear_param_store() @@ -492,12 +501,14 @@ def wrapped_model(): pyro.clear_param_store() for i in range(num_iterations): + # Call a progress hook at the beginning of each iteration. This is used to implement custom progress bars. + progress_hook(i, loss) loss = svi.step() if verbose: if i % 25 == 0: print(f"iteration {i}: loss = {loss}") - return inferred_parameters, loss + return {"inferred_parameters": inferred_parameters, "loss": loss} # # TODO diff --git a/setup.cfg b/setup.cfg index cb6432b71..915ecd375 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,7 @@ packages = pyciemss pyciemss.integration_utils pyciemss.mira_integration + pyciemss.ensemble [options.package_data] * = *.json diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index 0a11a91f1..8af70d3e4 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -210,7 +210,9 @@ def test_calibrate_no_kwargs(model_fixture, start_time, end_time, logging_step_s } with pyro.poutine.seed(rng_seed=0): - inferred_parameters, _ = calibrate(*calibrate_args, **calibrate_kwargs) + inferred_parameters = calibrate(*calibrate_args, **calibrate_kwargs)[ + "inferred_parameters" + ] assert isinstance(inferred_parameters, pyro.nn.PyroModule) @@ -255,7 +257,8 @@ def test_calibrate_deterministic( } with pyro.poutine.seed(rng_seed=0): - inferred_parameters, _ = calibrate(*calibrate_args, **calibrate_kwargs) + output = calibrate(*calibrate_args, **calibrate_kwargs) + inferred_parameters = output["inferred_parameters"] assert isinstance(inferred_parameters, pyro.nn.PyroModule) @@ -307,7 +310,7 @@ def test_calibrate_interventions( } with pyro.poutine.seed(rng_seed=0): - _, loss = calibrate(*calibrate_args, **calibrate_kwargs) + loss = calibrate(*calibrate_args, **calibrate_kwargs)["loss"] # SETUP INTERVENTION @@ -334,8 +337,11 @@ def time_key(time, _): } with pyro.poutine.seed(rng_seed=0): - intervened_parameters, intervened_loss = calibrate( - *calibrate_args, **calibrate_kwargs + output = calibrate(*calibrate_args, **calibrate_kwargs) + + intervened_parameters, intervened_loss = ( + output["inferred_parameters"], + output["loss"], ) assert intervened_loss != loss @@ -347,6 +353,49 @@ def time_key(time, _): check_result_sizes(result, start_time, end_time, logging_step_size, 1) +@pytest.mark.parametrize("model_fixture", MODELS) +@pytest.mark.parametrize("start_time", START_TIMES) +@pytest.mark.parametrize("end_time", END_TIMES) +@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES) +def test_calibrate_progress_hook( + model_fixture, start_time, end_time, logging_step_size +): + model_url = model_fixture.url + + ( + _, + calibrate_end_time, + sample_args, + sample_kwargs, + ) = setup_calibrate(model_fixture, start_time, end_time, logging_step_size) + + calibrate_args = [model_url, model_fixture.data_path] + + class TestProgressHook: + def __init__(self): + self.iterations = [] + self.losses = [] + + def __call__(self, iteration, loss): + # Log the loss and iteration number + self.iterations.append(iteration) + self.losses.append(loss) + + progress_hook = TestProgressHook() + + calibrate_kwargs = { + "data_mapping": model_fixture.data_mapping, + "start_time": start_time, + "progress_hook": progress_hook, + **CALIBRATE_KWARGS, + } + + calibrate(*calibrate_args, **calibrate_kwargs) + + assert len(progress_hook.iterations) == CALIBRATE_KWARGS["num_iterations"] + assert len(progress_hook.losses) == CALIBRATE_KWARGS["num_iterations"] + + @pytest.mark.parametrize("sample_method", SAMPLE_METHODS) @pytest.mark.parametrize("url", MODEL_URLS) @pytest.mark.parametrize("start_time", START_TIMES)