Skip to content

Commit

Permalink
Use refactor fixes early (#1)
Browse files Browse the repository at this point in the history
* Add `ensemble` dir to `setup.cfg`

* Align return type of `calibrate`

* Change "parameters" -> "inferred_parameters"

* added pyciemss logging wrapper to calibrate (ciemss#444)

* Add progress hook (and example usage in test) for `calibrate` (ciemss#445)

* added progress hook for calibrate

* added test with simple logging

* lint

* Fix formatting, linting, etc

---------

Co-authored-by: Sam Witty <samawitty@gmail.com>
  • Loading branch information
fivegrant and SamWitty authored Jan 3, 2024
1 parent f314fc2 commit 2669233
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 13 deletions.
27 changes: 19 additions & 8 deletions pyciemss/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import contextlib
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Union

import pyro
import torch
Expand Down Expand Up @@ -299,6 +299,7 @@ def wrapped_model():
return prepare_interchange_dictionary(samples)


@pyciemss_logging_wrapper
def calibrate(
model_path_or_json: Union[str, Dict],
data_path: str,
Expand All @@ -324,7 +325,8 @@ def calibrate(
verbose: bool = False,
num_particles: int = 1,
deterministic_learnable_parameters: List[str] = [],
) -> Tuple[pyro.nn.PyroModule, float]:
progress_hook: Callable = lambda i, loss: None,
) -> Dict[str, Any]:
"""
Infer parameters for a DynamicalSystem model conditional on data.
This uses variational inference with a mean-field variational family to infer the parameters of the model.
Expand Down Expand Up @@ -398,13 +400,20 @@ def calibrate(
- deterministic_learnable_parameters: List[str]
- A list of parameter names that should be learned deterministically.
- By default, all parameters are learned probabilistically.
- progress_hook: Callable[[int, float], None]
- A function that takes in the current iteration and the current loss.
- This is called at the beginning of each iteration.
- By default, this is a no-op.
- This can be used to implement custom progress bars.
Returns:
- inferred_parameters: pyro.nn.PyroModule
- A Pyro module that contains the inferred parameters of the model.
- This can be passed to `sample` to sample from the model conditional on the data.
- loss: float
- The final loss value of the approximate ELBO loss.
result: Dict[str, Any]
- Dictionary with the following key-value pairs.
- inferred_parameters: pyro.nn.PyroModule
- A Pyro module that contains the inferred parameters of the model.
- This can be passed to `sample` to sample from the model conditional on the data.
- loss: float
- The final loss value of the approximate ELBO loss.
"""

pyro.clear_param_store()
Expand Down Expand Up @@ -492,12 +501,14 @@ def wrapped_model():
pyro.clear_param_store()

for i in range(num_iterations):
# Call a progress hook at the beginning of each iteration. This is used to implement custom progress bars.
progress_hook(i, loss)
loss = svi.step()
if verbose:
if i % 25 == 0:
print(f"iteration {i}: loss = {loss}")

return inferred_parameters, loss
return {"inferred_parameters": inferred_parameters, "loss": loss}


# # TODO
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ packages =
pyciemss
pyciemss.integration_utils
pyciemss.mira_integration
pyciemss.ensemble

[options.package_data]
* = *.json
Expand Down
59 changes: 54 additions & 5 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def test_calibrate_no_kwargs(model_fixture, start_time, end_time, logging_step_s
}

with pyro.poutine.seed(rng_seed=0):
inferred_parameters, _ = calibrate(*calibrate_args, **calibrate_kwargs)
inferred_parameters = calibrate(*calibrate_args, **calibrate_kwargs)[
"inferred_parameters"
]

assert isinstance(inferred_parameters, pyro.nn.PyroModule)

Expand Down Expand Up @@ -255,7 +257,8 @@ def test_calibrate_deterministic(
}

with pyro.poutine.seed(rng_seed=0):
inferred_parameters, _ = calibrate(*calibrate_args, **calibrate_kwargs)
output = calibrate(*calibrate_args, **calibrate_kwargs)
inferred_parameters = output["inferred_parameters"]

assert isinstance(inferred_parameters, pyro.nn.PyroModule)

Expand Down Expand Up @@ -307,7 +310,7 @@ def test_calibrate_interventions(
}

with pyro.poutine.seed(rng_seed=0):
_, loss = calibrate(*calibrate_args, **calibrate_kwargs)
loss = calibrate(*calibrate_args, **calibrate_kwargs)["loss"]

# SETUP INTERVENTION

Expand All @@ -334,8 +337,11 @@ def time_key(time, _):
}

with pyro.poutine.seed(rng_seed=0):
intervened_parameters, intervened_loss = calibrate(
*calibrate_args, **calibrate_kwargs
output = calibrate(*calibrate_args, **calibrate_kwargs)

intervened_parameters, intervened_loss = (
output["inferred_parameters"],
output["loss"],
)

assert intervened_loss != loss
Expand All @@ -347,6 +353,49 @@ def time_key(time, _):
check_result_sizes(result, start_time, end_time, logging_step_size, 1)


@pytest.mark.parametrize("model_fixture", MODELS)
@pytest.mark.parametrize("start_time", START_TIMES)
@pytest.mark.parametrize("end_time", END_TIMES)
@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
def test_calibrate_progress_hook(
model_fixture, start_time, end_time, logging_step_size
):
model_url = model_fixture.url

(
_,
calibrate_end_time,
sample_args,
sample_kwargs,
) = setup_calibrate(model_fixture, start_time, end_time, logging_step_size)

calibrate_args = [model_url, model_fixture.data_path]

class TestProgressHook:
def __init__(self):
self.iterations = []
self.losses = []

def __call__(self, iteration, loss):
# Log the loss and iteration number
self.iterations.append(iteration)
self.losses.append(loss)

progress_hook = TestProgressHook()

calibrate_kwargs = {
"data_mapping": model_fixture.data_mapping,
"start_time": start_time,
"progress_hook": progress_hook,
**CALIBRATE_KWARGS,
}

calibrate(*calibrate_args, **calibrate_kwargs)

assert len(progress_hook.iterations) == CALIBRATE_KWARGS["num_iterations"]
assert len(progress_hook.losses) == CALIBRATE_KWARGS["num_iterations"]


@pytest.mark.parametrize("sample_method", SAMPLE_METHODS)
@pytest.mark.parametrize("url", MODEL_URLS)
@pytest.mark.parametrize("start_time", START_TIMES)
Expand Down

0 comments on commit 2669233

Please sign in to comment.