Alter and bugfix for service integration (#443)
* Add `ensemble` dir to `setup.cfg`

* Align return type of `calibrate`

* Change "parameters" -> "inferred_parameters"

* Fix formatting, linting, etc

* Bump version
fivegrant authored Jan 10, 2024
1 parent 671e405 commit d03ef2b
Showing 3 changed files with 23 additions and 14 deletions.
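
The core interface change is that `calibrate` now returns a dictionary instead of a `(PyroModule, float)` tuple. A minimal migration sketch, using only the key names and types documented in the diff below; the import path follows the changed file, and the positional `model` and `data` arguments are hypothetical placeholders:

from pyciemss.interfaces import calibrate

# Before this commit the call unpacked a tuple (hypothetical arguments):
# inferred_parameters, loss = calibrate(model, data)

# After this commit the result is a dict keyed by "inferred_parameters" and "loss":
result = calibrate(model, data)
inferred_parameters = result["inferred_parameters"]  # pyro.nn.PyroModule; can be passed to `sample`
loss = result["loss"]  # final approximate ELBO loss
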
18 changes: 10 additions & 8 deletions pyciemss/interfaces.py
@@ -1,5 +1,5 @@
import contextlib
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union

import pyro
import torch
@@ -331,7 +331,7 @@ def calibrate(
num_particles: int = 1,
deterministic_learnable_parameters: List[str] = [],
progress_hook: Callable = lambda i, loss: None,
-) -> Tuple[pyro.nn.PyroModule, float]:
+) -> Dict[str, Any]:
"""
Infer parameters for a DynamicalSystem model conditional on data.
This uses variational inference with a mean-field variational family to infer the parameters of the model.
@@ -412,11 +412,13 @@ def calibrate(
- This can be used to implement custom progress bars.
Returns:
-- inferred_parameters: pyro.nn.PyroModule
--     A Pyro module that contains the inferred parameters of the model.
--     This can be passed to `sample` to sample from the model conditional on the data.
-- loss: float
--     The final loss value of the approximate ELBO loss.
+result: Dict[str, Any]
+    - Dictionary with the following key-value pairs.
+    - inferred_parameters: pyro.nn.PyroModule
+        - A Pyro module that contains the inferred parameters of the model.
+        - This can be passed to `sample` to sample from the model conditional on the data.
+    - loss: float
+        - The final loss value of the approximate ELBO loss.
"""

pyro.clear_param_store()
@@ -511,7 +513,7 @@ def wrapped_model():
if i % 25 == 0:
print(f"iteration {i}: loss = {loss}")

-return inferred_parameters, loss
+return {"inferred_parameters": inferred_parameters, "loss": loss}


# # TODO
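
As the updated docstring notes, `progress_hook` can be used to implement custom progress reporting. A minimal sketch, assuming the hook is called once per iteration with the `(i, loss)` signature implied by the default `lambda i, loss: None`; the positional arguments to `calibrate` are hypothetical placeholders:

from pyciemss.interfaces import calibrate

def print_every_ten(i, loss):
    # Report the running ELBO loss every 10 iterations.
    if i % 10 == 0:
        print(f"iteration {i}: loss = {loss}")

result = calibrate(model, data, progress_hook=print_every_ten)  # hypothetical model/data arguments
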
3 changes: 2 additions & 1 deletion setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = pyciemss
-version = 0.1.0
+version = 0.1.1

license = BSD-3-Clause
license_files = LICENSE
@@ -24,6 +24,7 @@ packages =
pyciemss
pyciemss.integration_utils
pyciemss.mira_integration
+pyciemss.ensemble

[options.package_data]
* = *.json
16 changes: 11 additions & 5 deletions tests/test_interfaces.py
@@ -210,7 +210,9 @@ def test_calibrate_no_kwargs(model_fixture, start_time, end_time, logging_step_s
}

with pyro.poutine.seed(rng_seed=0):
-inferred_parameters, _ = calibrate(*calibrate_args, **calibrate_kwargs)
+inferred_parameters = calibrate(*calibrate_args, **calibrate_kwargs)[
+"inferred_parameters"
+]

assert isinstance(inferred_parameters, pyro.nn.PyroModule)

@@ -255,7 +257,8 @@ def test_calibrate_deterministic(
}

with pyro.poutine.seed(rng_seed=0):
-inferred_parameters, _ = calibrate(*calibrate_args, **calibrate_kwargs)
+output = calibrate(*calibrate_args, **calibrate_kwargs)
+inferred_parameters = output["inferred_parameters"]

assert isinstance(inferred_parameters, pyro.nn.PyroModule)

@@ -307,7 +310,7 @@ def test_calibrate_interventions(
}

with pyro.poutine.seed(rng_seed=0):
-_, loss = calibrate(*calibrate_args, **calibrate_kwargs)
+loss = calibrate(*calibrate_args, **calibrate_kwargs)["loss"]

# SETUP INTERVENTION

@@ -334,8 +337,11 @@ def time_key(time, _):
}

with pyro.poutine.seed(rng_seed=0):
-intervened_parameters, intervened_loss = calibrate(
-*calibrate_args, **calibrate_kwargs
+output = calibrate(*calibrate_args, **calibrate_kwargs)
+
+intervened_parameters, intervened_loss = (
+output["inferred_parameters"],
+output["loss"],
)

assert intervened_loss != loss
