diff --git a/firecrown/likelihood/gauss_family/gauss_family.py b/firecrown/likelihood/gauss_family/gauss_family.py index 69a8224c..84179a0e 100644 --- a/firecrown/likelihood/gauss_family/gauss_family.py +++ b/firecrown/likelihood/gauss_family/gauss_family.py @@ -9,10 +9,12 @@ from __future__ import annotations from enum import Enum -from typing import List, Optional, Tuple, Sequence +from functools import wraps +from typing import List, Optional, Tuple, Sequence, Callable, Union, TypeVar from typing import final import warnings +from typing_extensions import ParamSpec import numpy as np import numpy.typing as npt import scipy.linalg @@ -34,6 +36,68 @@ class State(Enum): UPDATED = 3 +T = TypeVar("T") +P = ParamSpec("P") + + +# See https://peps.python.org/pep-0612/ and +# https://stackoverflow.com/questions/66408662/type-annotations-for-decorators +# for how to specify the types of *args and **kwargs, and the return type of +# the method being decorated. + + +# Beware +def enforce_states( + *, + initial: Union[State, List[State]], + terminal: Optional[State] = None, + failure_message: str, +) -> Callable[[Callable[P, T]], Callable[P, T]]: + """This decorator wraps a method, and enforces state machine behavior. If + the object is not in one of the states in initial, an + AssertionError is raised with the given failure_message. + If terminal is None the state of the object is not modified. + If terminal is not None and the call to the wrapped method returns + normally the state of the object is set to terminal. + """ + initials: List[State] + if isinstance(initial, list): + initials = initial + else: + initials = [initial] + + def decorator_enforce_states(func: Callable[P, T]) -> Callable[P, T]: + """This part of the decorator is the closure that actually contains the + values of initials, terminal, and failure_message. + """ + + @wraps(func) + def wrapper_repeat(*args: P.args, **kwargs: P.kwargs) -> T: + """This part of the decorator is the actual wrapped method. It is + responsible for confirming a correct initial state, and + establishing the correct final state if the wrapped method + succeeds. + """ + # The odd use of args[0] instead of self seems to be the only way + # to have both the Python runtime and mypy agree on what is being + # passed to the method, and to allow access to the attribute + # 'state'. Recall that the syntax: + # o.foo() + # calls a bound function object accessible as o.foo; this bound + # function object calls the function foo() passing 'o' as the + # first argument, self. + assert isinstance(args[0], GaussFamily) + assert args[0].state in initials, failure_message + value = func(*args, **kwargs) + if terminal is not None: + args[0].state = terminal + return value + + return wrapper_repeat + + return decorator_enforce_states + + class GaussFamily(Likelihood): """GaussFamily is an abstract class. It is the base class for all likelihoods based on a chi-squared calculation. It provides an implementation of @@ -51,12 +115,16 @@ class GaussFamily(Likelihood): :meth:`calculate_loglike` or :meth:`get_data_vector`, or to reset the object (returning to the pre-update state) by calling :meth:`reset`. + + This state machine behavior is enforced through the use of the decorator + :meth:`enforce_states`, above. """ def __init__( self, statistics: Sequence[Statistic], ): + """Initialize the base class parts of a GaussFamily object.""" super().__init__() self.state: State = State.INITIALIZED if len(statistics) == 0: @@ -68,28 +136,38 @@ def __init__( self.cholesky: Optional[npt.NDArray[np.float64]] = None self.inv_cov: Optional[npt.NDArray[np.float64]] = None + @enforce_states( + initial=State.READY, + terminal=State.UPDATED, + failure_message="read() must be called before update()", + ) def _update(self, _: ParamsMap) -> None: """Handle the state resetting required by :class:`GaussFamily` likelihoods. Any derived class that needs to implement :meth:`_update` for its own reasons must be sure to do what this does: check the state at the start of the method, and change the state at the end of the method.""" - assert self.state == State.READY, "read() must be called before update()" - self.state = State.UPDATED + @enforce_states( + initial=State.UPDATED, + terminal=State.READY, + failure_message="update() must be called before reset()", + ) def _reset(self) -> None: """Handle the state resetting required by :class:`GaussFamily` likelihoods. Any derived class that needs to implement :meth:`reset` for its own reasons must be sure to do what this does: check the state at the start of the method, and change the state at the end of the method.""" - assert self.state == State.UPDATED, "update() must be called before reset()" - self.state = State.READY + @enforce_states( + initial=State.INITIALIZED, + terminal=State.READY, + failure_message="read() must only be called once", + ) def read(self, sacc_data: sacc.Sacc) -> None: """Read the covariance matrix for this likelihood from the SACC file.""" - assert self.state == State.INITIALIZED, "read() must only be called once" if sacc_data.covariance is None: msg = ( f"The {type(self).__name__} likelihood requires a covariance, " @@ -113,50 +191,53 @@ def read(self, sacc_data: sacc.Sacc) -> None: self.cholesky = scipy.linalg.cholesky(self.cov, lower=True) self.inv_cov = np.linalg.inv(cov) - self.state = State.READY - + @enforce_states( + initial=[State.READY, State.UPDATED], + failure_message="read() must be called before get_cov()", + ) @final def get_cov(self) -> npt.NDArray[np.float64]: """Gets the current covariance matrix.""" - assert self._is_ready(), "read() must be called before get_cov()" assert self.cov is not None - # We do not change the state. return self.cov @final + @enforce_states( + initial=[State.READY, State.UPDATED], + failure_message="read() must be called before get_data_vector()", + ) def get_data_vector(self) -> npt.NDArray[np.float64]: """Get the data vector from all statistics and concatenate in the right order.""" - assert self._is_ready(), "read() must be called before get_data_vector()" - data_vector_list: List[npt.NDArray[np.float64]] = [ stat.get_data_vector() for stat in self.statistics ] - # We do not change the state. return np.concatenate(data_vector_list) @final + @enforce_states( + initial=State.UPDATED, + failure_message="update() must be called before compute_theory_vector()", + ) def compute_theory_vector(self, tools: ModelingTools) -> npt.NDArray[np.float64]: """Computes the theory vector using the current instance of pyccl.Cosmology. :param tools: Current ModelingTools object """ - assert ( - self.state == State.UPDATED - ), "update() must be called before compute_theory_vector()" - theory_vector_list: List[npt.NDArray[np.float64]] = [ stat.compute_theory_vector(tools) for stat in self.statistics ] - # We do not change the state return np.concatenate(theory_vector_list) @final + @enforce_states( + initial=State.UPDATED, + failure_message="update() must be called before compute()", + ) def compute( self, tools: ModelingTools ) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: """Calculate and return both the data and theory vectors.""" - assert self.state == State.UPDATED, "update() must be called before compute()" warnings.simplefilter("always", DeprecationWarning) warnings.warn( "The use of the `compute` method on Statistic is deprecated." @@ -164,16 +245,15 @@ def compute( "`compute_theory_vector` instead.", category=DeprecationWarning, ) - - # We do not change the state. return self.get_data_vector(), self.compute_theory_vector(tools) @final + @enforce_states( + initial=State.UPDATED, + failure_message="update() must be called before compute_chisq()", + ) def compute_chisq(self, tools: ModelingTools) -> float: """Calculate and return the chi-squared for the given cosmology.""" - assert ( - self.state == State.UPDATED - ), "update() must be called before compute_chisq()" theory_vector: npt.NDArray[np.float64] data_vector: npt.NDArray[np.float64] residuals: npt.NDArray[np.float64] @@ -191,10 +271,4 @@ def compute_chisq(self, tools: ModelingTools) -> float: x = scipy.linalg.solve_triangular(self.cholesky, residuals, lower=True) chisq = np.dot(x, x) - - # We do not change the state. return chisq - - def _is_ready(self) -> bool: - """Return True if the state is either READY or UPDATED.""" - return self.state in (State.READY, State.UPDATED) diff --git a/tests/likelihood/gauss_family/test_const_gaussian.py b/tests/likelihood/gauss_family/test_const_gaussian.py index 5ac4e8d7..c0f9f997 100644 --- a/tests/likelihood/gauss_family/test_const_gaussian.py +++ b/tests/likelihood/gauss_family/test_const_gaussian.py @@ -33,7 +33,17 @@ def test_get_cov_works_after_read(trivial_stats, sacc_data_for_trivial_stat): assert np.all(likelihood.get_cov() == np.diag([4.0, 9.0, 16.0])) -def test_chisquared(trivial_stats, sacc_data_for_trivial_stat, trivial_params): +def test_compute_chisq_fails_before_read(trivial_stats): + """Note that the error message from the direct call to compute_chisq notes + that update() must be called; this can only be called after read().""" + likelihood = ConstGaussian(statistics=trivial_stats) + with pytest.raises( + AssertionError, match=r"update\(\) must be called before compute_chisq\(\)" + ): + _ = likelihood.compute_chisq(ModelingTools()) + + +def test_compute_chisq(trivial_stats, sacc_data_for_trivial_stat, trivial_params): likelihood = ConstGaussian(statistics=trivial_stats) likelihood.read(sacc_data_for_trivial_stat) likelihood.update(trivial_params) @@ -77,6 +87,14 @@ def test_missing_covariance(trivial_stats, sacc_with_data_points: sacc.Sacc): likelihood.read(sacc_with_data_points) +def test_get_data_vector_fails_before_read(trivial_stats): + likelihood = ConstGaussian(statistics=trivial_stats) + with pytest.raises( + AssertionError, match=r"read\(\) must be called before get_data_vector\(\)" + ): + _ = likelihood.get_data_vector() + + def test_using_good_sacc( trivial_stats, sacc_data_for_trivial_stat: sacc.Sacc,