diff --git a/firecrown/likelihood/gauss_family/gauss_family.py b/firecrown/likelihood/gauss_family/gauss_family.py index e5252358..c0b39580 100644 --- a/firecrown/likelihood/gauss_family/gauss_family.py +++ b/firecrown/likelihood/gauss_family/gauss_family.py @@ -27,6 +27,7 @@ from ...modeling_tools import ModelingTools from ...updatable import UpdatableCollection from .statistic.statistic import Statistic, GuardedStatistic +from ...utils import save_to_sacc class State(Enum): @@ -335,6 +336,29 @@ def compute_chisq(self, tools: ModelingTools) -> float: chisq = np.dot(x, x) return chisq + @enforce_states( + initial=[State.READY, State.UPDATED, State.COMPUTED], + failure_message="read() must be called before get_sacc_indices()", + ) + def get_sacc_indices( + self, statistic: Union[Statistic, list[Statistic], None] = None + ) -> npt.NDArray[np.int64]: + """Get the SACC indices of the statistic or list of statistics. If no + statistic is given, get the indices of all statistics of the likelihood.""" + if statistic is None: + statistic = [stat.statistic for stat in self.statistics] + if isinstance(statistic, Statistic): + statistic = [statistic] + + sacc_indices_list = [] + for stat in statistic: + assert stat.sacc_indices is not None + sacc_indices_list.append(stat.sacc_indices.copy()) + + sacc_indices = np.concatenate(sacc_indices_list) + + return sacc_indices + @enforce_states( initial=State.COMPUTED, failure_message="compute_theory_vector() must be called before " @@ -343,31 +367,18 @@ def compute_chisq(self, tools: ModelingTools) -> float: def make_realization( self, sacc_data: sacc.Sacc, add_noise: bool = True, strict: bool = True ) -> sacc.Sacc: - new_sacc = sacc_data.copy() - - sacc_indices_list = [] - for stat in self.statistics: - assert stat.statistic.sacc_indices is not None - sacc_indices_list.append(stat.statistic.sacc_indices.copy()) - - sacc_indices = np.concatenate(sacc_indices_list) + sacc_indices = self.get_sacc_indices() if add_noise: new_data_vector = self.make_realization_vector() else: new_data_vector = self.get_theory_vector() - assert len(sacc_indices) == len(new_data_vector) - - if strict: - if set(sacc_indices.tolist()) != set(sacc_data.indices()): - raise RuntimeError( - "The predicted data does not cover all the data in the " - "sacc object. To write only the calculated predictions, " - "set strict=False." - ) - - for prediction_idx, sacc_idx in enumerate(sacc_indices): - new_sacc.data[sacc_idx].value = new_data_vector[prediction_idx] + new_sacc = save_to_sacc( + sacc_data=sacc_data, + data_vector=new_data_vector, + indices=sacc_indices, + strict=strict, + ) return new_sacc diff --git a/firecrown/utils.py b/firecrown/utils.py index 8c56dbb5..7937f514 100644 --- a/firecrown/utils.py +++ b/firecrown/utils.py @@ -1,6 +1,13 @@ """Some utility functions for patterns common in Firecrown. """ +from __future__ import annotations + +import numpy as np +import numpy.typing as npt + +import sacc + def upper_triangle_indices(n: int): """generator that yields a sequence of tuples that carry the indices for an @@ -13,3 +20,51 @@ def upper_triangle_indices(n: int): for i in range(n): for j in range(i, n): yield i, j + + +def save_to_sacc( + sacc_data: sacc.Sacc, + data_vector: npt.NDArray[np.float64], + indices: npt.NDArray[np.int64], + strict: bool = True, +) -> sacc.Sacc: + """Save a data vector into a (new) SACC object, copied from `sacc_data`. + + Note that the original object `sacc_data` is not modified. Its contents are + copied into a new object, and the new information is put into that copy, + which is returned by this method. + + Arguments + --------- + sacc_data: sacc.Sacc + SACC object to be copied. It is not modified. + data_vector: np.ndarray[float] + Data vector to be saved to the new copy of `sacc_data`. + indices: np.ndarray[int] + SACC indices where the data vector should be written. + strict: bool + Whether to check if the data vector covers all the data already present + in the sacc_data. + + Returns + ------- + new_sacc: sacc.Sacc + A copy of `sacc_data`, with data at `indices` replaced with `data_vector`. + """ + + assert len(indices) == len(data_vector) + + new_sacc = sacc_data.copy() + + if strict: + if set(indices.tolist()) != set(sacc_data.indices()): + raise RuntimeError( + "The data to be saved does not cover all the data in the " + "sacc object. To write only the calculated predictions, " + "set strict=False." + ) + + for data_idx, sacc_idx in enumerate(indices): + new_sacc.data[sacc_idx].value = data_vector[data_idx] + + return new_sacc diff --git a/tests/likelihood/gauss_family/test_const_gaussian.py b/tests/likelihood/gauss_family/test_const_gaussian.py index d176dab5..802526a3 100644 --- a/tests/likelihood/gauss_family/test_const_gaussian.py +++ b/tests/likelihood/gauss_family/test_const_gaussian.py @@ -420,3 +420,32 @@ def test_make_realization_no_noise( new_likelihood.update(params) assert_allclose(new_likelihood.get_data_vector(), likelihood.get_theory_vector()) + + +def test_get_sacc_indices( + trivial_stats, + sacc_data_for_trivial_stat: sacc.Sacc, +): + likelihood = ConstGaussian(statistics=trivial_stats) + likelihood.read(sacc_data_for_trivial_stat) + + idx = likelihood.get_sacc_indices() + + assert all( + idx + == np.concatenate( + [stat.statistic.sacc_indices for stat in likelihood.statistics] + ) + ) + + +def test_get_sacc_indices_single_stat( + trivial_stats, + sacc_data_for_trivial_stat: sacc.Sacc, +): + likelihood = ConstGaussian(statistics=trivial_stats) + likelihood.read(sacc_data_for_trivial_stat) + + idx = likelihood.get_sacc_indices(statistic=likelihood.statistics[0].statistic) + + assert all(idx == likelihood.statistics[0].statistic.sacc_indices) diff --git a/tests/test_utils.py b/tests/test_utils.py index 5086a80d..89d50ea2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,10 @@ Tests for the firecrown.utils modle. """ -from firecrown.utils import upper_triangle_indices +import pytest +import numpy as np + +from firecrown.utils import upper_triangle_indices, save_to_sacc def test_upper_triangle_indices_nonzero(): @@ -13,3 +16,33 @@ def test_upper_triangle_indices_nonzero(): def test_upper_triangle_indices_zero(): indices = list(upper_triangle_indices(0)) assert not indices + + +def test_save_to_sacc(trivial_stats, sacc_data_for_trivial_stat): + stat = trivial_stats[0] + stat.read(sacc_data_for_trivial_stat) + idx = np.arange(stat.count) + new_data_vector = 3 * stat.get_data_vector()[idx] + + new_sacc = save_to_sacc( + sacc_data=sacc_data_for_trivial_stat, + data_vector=new_data_vector, + indices=idx, + strict=True, + ) + assert all(new_sacc.data[i].value == d for i, d in zip(idx, new_data_vector)) + + +def test_save_to_sacc_strict_fail(trivial_stats, sacc_data_for_trivial_stat): + stat = trivial_stats[0] + stat.read(sacc_data_for_trivial_stat) + idx = np.arange(stat.count - 1) + new_data_vector = stat.get_data_vector()[idx] + + with pytest.raises(RuntimeError): + _ = save_to_sacc( + sacc_data=sacc_data_for_trivial_stat, + data_vector=new_data_vector, + indices=idx, + strict=True, + )