Skip to content

Commit

Permalink
Yet even more docstrings (#445)
Browse files Browse the repository at this point in the history
* Improve docstrings and type annotations
* Fix bug in NumberCounts property computation
  • Loading branch information
marcpaterno authored Aug 13, 2024
1 parent a6c217a commit bd70476
Show file tree
Hide file tree
Showing 11 changed files with 294 additions and 89 deletions.
2 changes: 1 addition & 1 deletion firecrown/likelihood/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Classes used to represent likelihoods and functions to support them.
"""Classes used to represent likelihoods, and functions to support them.
Subpackages contain specific likelihood implementations, e.g., Gaussian and Student-t.
The submodule :mod:`firecrown.likelihood.likelihood` contain the abstract base class for
Expand Down
43 changes: 28 additions & 15 deletions firecrown/likelihood/binned_cluster_number_counts.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
"""This module holds classes needed to predict the binned cluster number counts.
The binned cluster number counts statistic predicts the number of galaxy
clusters within a single redshift and mass bin.
"""
"""Binned cluster number counts statistic support."""

from __future__ import annotations

Expand All @@ -26,11 +22,7 @@


class BinnedClusterNumberCounts(Statistic):
"""The Binned Cluster Number Counts statistic.
This class will make a prediction for the number of clusters in a z, mass bin
and compare that prediction to the data provided in the sacc file.
"""
"""A statistic representing the number of clusters in a z, mass bin."""

def __init__(
self,
Expand All @@ -39,6 +31,13 @@ def __init__(
cluster_recipe: ClusterRecipe,
systematics: None | list[SourceSystematic] = None,
):
"""Initialize this statistic.
:param cluster_properties: The cluster observables to use.
:param survey_name: The name of the survey to use.
:param cluster_recipe: The cluster recipe to use.
:param systematics: The systematics to apply to this statistic.
"""
super().__init__()
self.systematics = systematics or []
self.theory_vector: None | TheoryVector = None
Expand All @@ -50,7 +49,10 @@ def __init__(
self.bins: list[SaccBin] = []

def read(self, sacc_data: sacc.Sacc) -> None:
"""Read the data for this statistic and mark it as ready for use."""
"""Read the data for this statistic and mark it as ready for use.
:param sacc_data: The data in the sacc format.
"""
# Build the data vector and indices needed for the likelihood
if self.cluster_properties == ClusterProperty.NONE:
raise ValueError("You must specify at least one cluster property.")
Expand All @@ -77,12 +79,19 @@ def read(self, sacc_data: sacc.Sacc) -> None:
super().read(sacc_data)

def get_data_vector(self) -> DataVector:
"""Gets the statistic data vector."""
"""Gets the statistic data vector.
:return: The statistic data vector.
"""
assert self.data_vector is not None
return self.data_vector

def _compute_theory_vector(self, tools: ModelingTools) -> TheoryVector:
"""Compute a statistic from sources, concrete implementation."""
"""Compute a statistic from sources, concrete implementation.
:param tools: The modeling tools used to compute the statistic.
:return: The computed statistic.
"""
assert tools.cluster_abundance is not None

theory_vector_list: list[float] = []
Expand Down Expand Up @@ -116,6 +125,9 @@ def get_binned_cluster_property(
Using the data from the sacc file, this function evaluates the likelihood for
a single point of the parameter space, and returns the predicted mean mass of
the clusters in each bin.
:param cluster_counts: The number of clusters in each bin.
:param cluster_properties: The cluster observables to use.
"""
assert tools.cluster_abundance is not None

Expand All @@ -124,8 +136,6 @@ def get_binned_cluster_property(
total_observable = self.cluster_recipe.evaluate_theory_prediction(
tools.cluster_abundance, this_bin, self.sky_area, cluster_properties
)
cluster_counts.append(counts)

mean_observable = total_observable / counts
mean_values.append(mean_observable)

Expand All @@ -137,6 +147,9 @@ def get_binned_cluster_counts(self, tools: ModelingTools) -> list[float]:
Using the data from the sacc file, this function evaluates the likelihood for
a single point of the parameter space, and returns the predicted number of
clusters in each bin.
:param tools: The modeling tools used to compute the statistic.
:return: The number of clusters in each bin.
"""
assert tools.cluster_abundance is not None

Expand Down
2 changes: 2 additions & 0 deletions firecrown/likelihood/gauss_family/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
"""Backward compatibility support for deprecated directory structure."""

# flake8: noqa
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
"""Backward compatibility support for deprecated directory structure."""

# flake8: noqa
85 changes: 73 additions & 12 deletions firecrown/likelihood/gaussfamily.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Support for the family of Gaussian likelihood."""
"""Support for the family of Gaussian likelihoods."""

from __future__ import annotations

Expand Down Expand Up @@ -30,7 +30,12 @@


class State(Enum):
"""The states used in GaussFamily."""
"""The states used in GaussFamily.
GaussFamily and all subclasses enforce a statemachine behavior based on
these states to ensure that the necessary initialization and setup is done
in the correct order.
"""

INITIALIZED = 1
READY = 2
Expand Down Expand Up @@ -62,6 +67,12 @@ def enforce_states(
If terminal is None the state of the object is not modified.
If terminal is not None and the call to the wrapped method returns
normally the state of the object is set to terminal.
:param initial: The initial states allowable for the wrapped method
:param terminal: The terminal state ensured for the wrapped method. None
indicates no state change happens.
:param failure_message: The failure message for the AssertionError raised
:return: The wrapped method
"""
initials: list[State]
if isinstance(initial, list):
Expand All @@ -74,6 +85,9 @@ def decorator_enforce_states(func: Callable[P, T]) -> Callable[P, T]:
This closure is what actually contains the values of initials, terminal, and
failure_message.
:param func: The method to be wrapped
:return: The wrapped method
"""

@wraps(func)
Expand Down Expand Up @@ -132,8 +146,11 @@ class GaussFamily(Likelihood):
def __init__(
self,
statistics: Sequence[Statistic],
):
"""Initialize the base class parts of a GaussFamily object."""
) -> None:
"""Initialize the base class parts of a GaussFamily object.
:param statistics: A list of statistics to be include in chisquared calculations
"""
super().__init__()
self.state: State = State.INITIALIZED
if len(statistics) == 0:
Expand All @@ -160,7 +177,12 @@ def __init__(
def create_ready(
cls, statistics: Sequence[Statistic], covariance: npt.NDArray[np.float64]
) -> GaussFamily:
"""Create a GaussFamily object in the READY state."""
"""Create a GaussFamily object in the READY state.
:param statistics: A list of statistics to be include in chisquared calculations
:param covariance: The covariance matrix of the statistics
:return: A ready GaussFamily object
"""
obj = cls(statistics)
obj._set_covariance(covariance)
obj.state = State.READY
Expand All @@ -178,6 +200,8 @@ def _update(self, _: ParamsMap) -> None:
for its own reasons must be sure to do what this does: check the state
at the start of the method, and change the state at the end of the
method.
:param _: a ParamsMap object, not used
"""

@enforce_states(
Expand All @@ -201,7 +225,10 @@ def _reset(self) -> None:
failure_message="read() must only be called once",
)
def read(self, sacc_data: sacc.Sacc) -> None:
"""Read the covariance matrix for this likelihood from the SACC file."""
"""Read the covariance matrix for this likelihood from the SACC file.
:param sacc_data: The SACC data object to be read
"""
if sacc_data.covariance is None:
msg = (
f"The {type(self).__name__} likelihood requires a covariance, "
Expand All @@ -216,11 +243,13 @@ def read(self, sacc_data: sacc.Sacc) -> None:

self._set_covariance(covariance)

def _set_covariance(self, covariance):
def _set_covariance(self, covariance: npt.NDArray[np.float64]) -> None:
"""Set the covariance matrix.
This method is used to set the covariance matrix and perform the
necessary calculations to prepare the likelihood for computation.
:param covariance: The covariance matrix for this likelihood
"""
indices_list = []
data_vector_list = []
Expand Down Expand Up @@ -276,6 +305,7 @@ def get_cov(
:param statistic: The statistic for which the sub-covariance matrix
should be returned. If not specified, return the covariance of all
statistics.
:return: The covariance matrix (or portion thereof)
"""
assert self.cov is not None
if statistic is None:
Expand All @@ -301,7 +331,10 @@ def get_cov(
failure_message="read() must be called before get_data_vector()",
)
def get_data_vector(self) -> npt.NDArray[np.float64]:
"""Get the data vector from all statistics in the right order."""
"""Get the data vector from all statistics in the right order.
:return: The data vector
"""
assert self.data_vector is not None
return self.data_vector

Expand All @@ -315,6 +348,7 @@ def compute_theory_vector(self, tools: ModelingTools) -> npt.NDArray[np.float64]
"""Computes the theory vector using the current instance of pyccl.Cosmology.
:param tools: Current ModelingTools object
:return: The computed theory vector
"""
theory_vector_list: list[npt.NDArray[np.float64]] = [
stat.compute_theory_vector(tools) for stat in self.statistics
Expand All @@ -329,7 +363,10 @@ def compute_theory_vector(self, tools: ModelingTools) -> npt.NDArray[np.float64]
"get_theory_vector()",
)
def get_theory_vector(self) -> npt.NDArray[np.float64]:
"""Get the theory vector from all statistics in the right order."""
"""Get the already-computed theory vector from all statistics.
:return: The theory vector, with all statistics in the right order
"""
assert (
self.theory_vector is not None
), "theory_vector is None after compute_theory_vector() has been called"
Expand All @@ -343,7 +380,14 @@ def get_theory_vector(self) -> npt.NDArray[np.float64]:
def compute(
self, tools: ModelingTools
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
"""Calculate and return both the data and theory vectors."""
"""Calculate and return both the data and theory vectors.
This method is dprecated and will be removed in a future version of Firecrown.
:param tools: the ModelingTools to be used in the calculation of the
theory vector
:return: a tuple containing the data vector and the theory vector
"""
warnings.warn(
"The use of the `compute` method on Statistic is deprecated."
"The Statistic objects should implement `get_data` and "
Expand All @@ -359,7 +403,12 @@ def compute(
failure_message="update() must be called before compute_chisq()",
)
def compute_chisq(self, tools: ModelingTools) -> float:
"""Calculate and return the chi-squared for the given cosmology."""
"""Calculate and return the chi-squared for the given cosmology.
:param tools: the ModelingTools to be used in the calculation of the
theory vector
:return: the chi-squared
"""
theory_vector: npt.NDArray[np.float64]
data_vector: npt.NDArray[np.float64]
residuals: npt.NDArray[np.float64]
Expand All @@ -386,6 +435,10 @@ def get_sacc_indices(
"""Get the SACC indices of the statistic or list of statistics.
If no statistic is given, get the indices of all statistics of the likelihood.
:param statistics: The statistic or list of statistics for which the
SACC indices are desired
:return: The SACC indices
"""
if statistic is None:
statistic = [stat.statistic for stat in self.statistics]
Expand All @@ -409,7 +462,15 @@ def get_sacc_indices(
def make_realization(
self, sacc_data: sacc.Sacc, add_noise: bool = True, strict: bool = True
) -> sacc.Sacc:
"""Create a new realization of the model."""
"""Create a new realization of the model.
:param sacc_data: The SACC data object containing the covariance matrix
to be read
:param add_noise: If True, add noise to the realization.
:param strict: If True, check that the indices of the realization cover
all the indices of the SACC data object.
:return: The SACC data object containing the new realization
"""
sacc_indices = self.get_sacc_indices()

if add_noise:
Expand Down
13 changes: 10 additions & 3 deletions firecrown/likelihood/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@
class ConstGaussian(GaussFamily):
"""A Gaussian log-likelihood with a constant covariance matrix."""

def compute_loglike(self, tools: ModelingTools):
"""Compute the log-likelihood."""
def compute_loglike(self, tools: ModelingTools) -> float:
"""Compute the log-likelihood.
:params tools: The modeling tools used to compute the likelihood.
:return: The log-likelihood.
"""
return -0.5 * self.compute_chisq(tools)

def make_realization_vector(self) -> np.ndarray:
"""Create a new realization of the model."""
"""Create a new (randomized) realization of the model.
:return: A new realization of the model
"""
theory_vector = self.get_theory_vector()
assert self.cholesky is not None
new_data_vector = theory_vector + np.dot(
Expand Down
Loading

0 comments on commit bd70476

Please sign in to comment.