Skip to content

Commit

Permalink
Add decorator implementation of state machine (#361)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcpaterno authored Jan 25, 2024
1 parent a65e684 commit cd7a3e2
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 31 deletions.
134 changes: 104 additions & 30 deletions firecrown/likelihood/gauss_family/gauss_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@

from __future__ import annotations
from enum import Enum
from typing import List, Optional, Tuple, Sequence
from functools import wraps
from typing import List, Optional, Tuple, Sequence, Callable, Union, TypeVar
from typing import final
import warnings

from typing_extensions import ParamSpec
import numpy as np
import numpy.typing as npt
import scipy.linalg
Expand All @@ -34,6 +36,68 @@ class State(Enum):
UPDATED = 3


T = TypeVar("T")
P = ParamSpec("P")


# See https://peps.python.org/pep-0612/ and
# https://stackoverflow.com/questions/66408662/type-annotations-for-decorators
# for how to specify the types of *args and **kwargs, and the return type of
# the method being decorated.


# Beware
def enforce_states(
    *,
    initial: Union[State, List[State]],
    terminal: Optional[State] = None,
    failure_message: str,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Build a decorator that enforces state machine behavior on a method.

    The decorated method may only be called while the object it is bound to
    is in one of the states given by ``initial``. If the method returns
    normally and ``terminal`` is not None, the object's state is then set to
    ``terminal``. If the method raises, the state is left untouched.

    :param initial: the state, or list of states, in which the object must be
        when the decorated method is called
    :param terminal: the state to establish after a successful call; None
        means the state is not modified
    :param failure_message: the message carried by the AssertionError raised
        when the object is not in one of the ``initial`` states
    :return: a decorator to be applied to a method of :class:`GaussFamily`
    :raises AssertionError: (from the wrapped method, at call time) if the
        object is not in one of the ``initial`` states
    """
    # Take a copy of a caller-supplied list, so that mutating that list after
    # decoration can not change the set of permitted states.
    initials: List[State]
    if isinstance(initial, list):
        initials = list(initial)
    else:
        initials = [initial]

    def decorator_enforce_states(func: Callable[P, T]) -> Callable[P, T]:
        """This part of the decorator is the closure that actually contains the
        values of initials, terminal, and failure_message.
        """

        @wraps(func)
        def wrapper_enforce_states(*args: P.args, **kwargs: P.kwargs) -> T:
            """This part of the decorator is the actual wrapped method. It is
            responsible for confirming a correct initial state, and
            establishing the correct final state if the wrapped method
            succeeds.
            """
            # The odd use of args[0] instead of self seems to be the only way
            # to have both the Python runtime and mypy agree on what is being
            # passed to the method, and to allow access to the attribute
            # 'state'. Recall that the syntax:
            #      o.foo()
            # calls a bound function object accessible as o.foo; this bound
            # function object calls the function foo() passing 'o' as the
            # first argument, self.
            assert isinstance(args[0], GaussFamily)
            # Raise explicitly rather than using an assert statement: assert
            # statements are removed when Python runs with -O, which would
            # silently turn off the state checking. The exception type is
            # kept as AssertionError so that existing callers are unaffected.
            if args[0].state not in initials:
                raise AssertionError(failure_message)
            value = func(*args, **kwargs)
            # Only reached if func returned normally; an exception leaves the
            # state as it was.
            if terminal is not None:
                args[0].state = terminal
            return value

        return wrapper_enforce_states

    return decorator_enforce_states


class GaussFamily(Likelihood):
"""GaussFamily is an abstract class. It is the base class for all likelihoods
based on a chi-squared calculation. It provides an implementation of
Expand All @@ -51,12 +115,16 @@ class GaussFamily(Likelihood):
:meth:`calculate_loglike` or :meth:`get_data_vector`, or to reset
the object (returning to the pre-update state) by calling
:meth:`reset`.
This state machine behavior is enforced through the use of the decorator
:func:`enforce_states`, above.
"""

def __init__(
self,
statistics: Sequence[Statistic],
):
"""Initialize the base class parts of a GaussFamily object."""
super().__init__()
self.state: State = State.INITIALIZED
if len(statistics) == 0:
Expand All @@ -68,28 +136,38 @@ def __init__(
self.cholesky: Optional[npt.NDArray[np.float64]] = None
self.inv_cov: Optional[npt.NDArray[np.float64]] = None

@enforce_states(
initial=State.READY,
terminal=State.UPDATED,
failure_message="read() must be called before update()",
)
def _update(self, _: ParamsMap) -> None:
"""Handle the state resetting required by :class:`GaussFamily`
likelihoods. Any derived class that needs to implement :meth:`_update`
for its own reasons must be sure to do what this does: check the state
at the start of the method, and change the state at the end of the
method."""
assert self.state == State.READY, "read() must be called before update()"
self.state = State.UPDATED

@enforce_states(
initial=State.UPDATED,
terminal=State.READY,
failure_message="update() must be called before reset()",
)
def _reset(self) -> None:
"""Handle the state resetting required by :class:`GaussFamily`
likelihoods. Any derived class that needs to implement :meth:`reset`
for its own reasons must be sure to do what this does: check the state
at the start of the method, and change the state at the end of the
method."""
assert self.state == State.UPDATED, "update() must be called before reset()"
self.state = State.READY

@enforce_states(
initial=State.INITIALIZED,
terminal=State.READY,
failure_message="read() must only be called once",
)
def read(self, sacc_data: sacc.Sacc) -> None:
"""Read the covariance matrix for this likelihood from the SACC file."""

assert self.state == State.INITIALIZED, "read() must only be called once"
if sacc_data.covariance is None:
msg = (
f"The {type(self).__name__} likelihood requires a covariance, "
Expand All @@ -113,67 +191,69 @@ def read(self, sacc_data: sacc.Sacc) -> None:
self.cholesky = scipy.linalg.cholesky(self.cov, lower=True)
self.inv_cov = np.linalg.inv(cov)

self.state = State.READY

@enforce_states(
initial=[State.READY, State.UPDATED],
failure_message="read() must be called before get_cov()",
)
@final
def get_cov(self) -> npt.NDArray[np.float64]:
"""Gets the current covariance matrix."""
assert self._is_ready(), "read() must be called before get_cov()"
assert self.cov is not None
# We do not change the state.
return self.cov

@final
@enforce_states(
initial=[State.READY, State.UPDATED],
failure_message="read() must be called before get_data_vector()",
)
def get_data_vector(self) -> npt.NDArray[np.float64]:
"""Get the data vector from all statistics and concatenate in the right
order."""
assert self._is_ready(), "read() must be called before get_data_vector()"

data_vector_list: List[npt.NDArray[np.float64]] = [
stat.get_data_vector() for stat in self.statistics
]
# We do not change the state.
return np.concatenate(data_vector_list)

@final
@enforce_states(
initial=State.UPDATED,
failure_message="update() must be called before compute_theory_vector()",
)
def compute_theory_vector(self, tools: ModelingTools) -> npt.NDArray[np.float64]:
"""Computes the theory vector using the current instance of pyccl.Cosmology.
:param tools: Current ModelingTools object
"""
assert (
self.state == State.UPDATED
), "update() must be called before compute_theory_vector()"

theory_vector_list: List[npt.NDArray[np.float64]] = [
stat.compute_theory_vector(tools) for stat in self.statistics
]
# We do not change the state
return np.concatenate(theory_vector_list)

@final
@enforce_states(
initial=State.UPDATED,
failure_message="update() must be called before compute()",
)
def compute(
self, tools: ModelingTools
) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
"""Calculate and return both the data and theory vectors."""
assert self.state == State.UPDATED, "update() must be called before compute()"
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
"The use of the `compute` method on Statistic is deprecated."
"The Statistic objects should implement `get_data` and "
"`compute_theory_vector` instead.",
category=DeprecationWarning,
)

# We do not change the state.
return self.get_data_vector(), self.compute_theory_vector(tools)

@final
@enforce_states(
initial=State.UPDATED,
failure_message="update() must be called before compute_chisq()",
)
def compute_chisq(self, tools: ModelingTools) -> float:
"""Calculate and return the chi-squared for the given cosmology."""
assert (
self.state == State.UPDATED
), "update() must be called before compute_chisq()"
theory_vector: npt.NDArray[np.float64]
data_vector: npt.NDArray[np.float64]
residuals: npt.NDArray[np.float64]
Expand All @@ -191,10 +271,4 @@ def compute_chisq(self, tools: ModelingTools) -> float:

x = scipy.linalg.solve_triangular(self.cholesky, residuals, lower=True)
chisq = np.dot(x, x)

# We do not change the state.
return chisq

def _is_ready(self) -> bool:
"""Return True if the state is either READY or UPDATED."""
return self.state in (State.READY, State.UPDATED)
20 changes: 19 additions & 1 deletion tests/likelihood/gauss_family/test_const_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,17 @@ def test_get_cov_works_after_read(trivial_stats, sacc_data_for_trivial_stat):
assert np.all(likelihood.get_cov() == np.diag([4.0, 9.0, 16.0]))


def test_chisquared(trivial_stats, sacc_data_for_trivial_stat, trivial_params):
def test_compute_chisq_fails_before_read(trivial_stats):
"""Note that the error message from the direct call to compute_chisq says
that update() must be called first; update() itself can only be called
after read()."""
likelihood = ConstGaussian(statistics=trivial_stats)
with pytest.raises(
AssertionError, match=r"update\(\) must be called before compute_chisq\(\)"
):
_ = likelihood.compute_chisq(ModelingTools())


def test_compute_chisq(trivial_stats, sacc_data_for_trivial_stat, trivial_params):
likelihood = ConstGaussian(statistics=trivial_stats)
likelihood.read(sacc_data_for_trivial_stat)
likelihood.update(trivial_params)
Expand Down Expand Up @@ -77,6 +87,14 @@ def test_missing_covariance(trivial_stats, sacc_with_data_points: sacc.Sacc):
likelihood.read(sacc_with_data_points)


def test_get_data_vector_fails_before_read(trivial_stats):
likelihood = ConstGaussian(statistics=trivial_stats)
with pytest.raises(
AssertionError, match=r"read\(\) must be called before get_data_vector\(\)"
):
_ = likelihood.get_data_vector()


def test_using_good_sacc(
trivial_stats,
sacc_data_for_trivial_stat: sacc.Sacc,
Expand Down

0 comments on commit cd7a3e2

Please sign in to comment.