Skip to content

Commit

Permalink
Add a save_to_sacc utility function (#390)
Browse files Browse the repository at this point in the history
* add save_to_sacc_utility

Co-authored-by: Marc Paterno <paterno@fnal.gov>
  • Loading branch information
tilmantroester and marcpaterno authored Feb 22, 2024
1 parent 946e50f commit 73af686
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 21 deletions.
51 changes: 31 additions & 20 deletions firecrown/likelihood/gauss_family/gauss_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ...modeling_tools import ModelingTools
from ...updatable import UpdatableCollection
from .statistic.statistic import Statistic, GuardedStatistic
from ...utils import save_to_sacc


class State(Enum):
Expand Down Expand Up @@ -335,6 +336,29 @@ def compute_chisq(self, tools: ModelingTools) -> float:
chisq = np.dot(x, x)
return chisq

@enforce_states(
initial=[State.READY, State.UPDATED, State.COMPUTED],
failure_message="read() must be called before get_sacc_indices()",
)
def get_sacc_indices(
self, statistic: Union[Statistic, list[Statistic], None] = None
) -> npt.NDArray[np.int64]:
"""Get the SACC indices of the statistic or list of statistics. If no
statistic is given, get the indices of all statistics of the likelihood."""
if statistic is None:
statistic = [stat.statistic for stat in self.statistics]
if isinstance(statistic, Statistic):
statistic = [statistic]

sacc_indices_list = []
for stat in statistic:
assert stat.sacc_indices is not None
sacc_indices_list.append(stat.sacc_indices.copy())

sacc_indices = np.concatenate(sacc_indices_list)

return sacc_indices

@enforce_states(
initial=State.COMPUTED,
failure_message="compute_theory_vector() must be called before "
Expand All @@ -343,31 +367,18 @@ def compute_chisq(self, tools: ModelingTools) -> float:
def make_realization(
self, sacc_data: sacc.Sacc, add_noise: bool = True, strict: bool = True
) -> sacc.Sacc:
new_sacc = sacc_data.copy()

sacc_indices_list = []
for stat in self.statistics:
assert stat.statistic.sacc_indices is not None
sacc_indices_list.append(stat.statistic.sacc_indices.copy())

sacc_indices = np.concatenate(sacc_indices_list)
sacc_indices = self.get_sacc_indices()

if add_noise:
new_data_vector = self.make_realization_vector()
else:
new_data_vector = self.get_theory_vector()

assert len(sacc_indices) == len(new_data_vector)

if strict:
if set(sacc_indices.tolist()) != set(sacc_data.indices()):
raise RuntimeError(
"The predicted data does not cover all the data in the "
"sacc object. To write only the calculated predictions, "
"set strict=False."
)

for prediction_idx, sacc_idx in enumerate(sacc_indices):
new_sacc.data[sacc_idx].value = new_data_vector[prediction_idx]
new_sacc = save_to_sacc(
sacc_data=sacc_data,
data_vector=new_data_vector,
indices=sacc_indices,
strict=strict,
)

return new_sacc
55 changes: 55 additions & 0 deletions firecrown/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
"""Some utility functions for patterns common in Firecrown.
"""

from __future__ import annotations

import numpy as np
import numpy.typing as npt

import sacc


def upper_triangle_indices(n: int):
"""generator that yields a sequence of tuples that carry the indices for an
Expand All @@ -13,3 +20,51 @@ def upper_triangle_indices(n: int):
for i in range(n):
for j in range(i, n):
yield i, j


def save_to_sacc(
sacc_data: sacc.Sacc,
data_vector: npt.NDArray[np.float64],
indices: npt.NDArray[np.int64],
strict: bool = True,
) -> sacc.Sacc:
"""Save a data vector into a (new) SACC object, copied from `sacc_data`.
Note that the original object `sacc_data` is not modified. Its contents are
copied into a new object, and the new information is put into that copy,
which is returned by this method.
Arguments
---------
sacc_data: sacc.Sacc
SACC object to be copied. It is not modified.
data_vector: np.ndarray[float]
Data vector to be saved to the new copy of `sacc_data`.
indices: np.ndarray[int]
SACC indices where the data vector should be written.
strict: bool
Whether to check if the data vector covers all the data already present
in the sacc_data.
Returns
-------
new_sacc: sacc.Sacc
A copy of `sacc_data`, with data at `indices` replaced with `data_vector`.
"""

assert len(indices) == len(data_vector)

new_sacc = sacc_data.copy()

if strict:
if set(indices.tolist()) != set(sacc_data.indices()):
raise RuntimeError(
"The data to be saved does not cover all the data in the "
"sacc object. To write only the calculated predictions, "
"set strict=False."
)

for data_idx, sacc_idx in enumerate(indices):
new_sacc.data[sacc_idx].value = data_vector[data_idx]

return new_sacc
29 changes: 29 additions & 0 deletions tests/likelihood/gauss_family/test_const_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,32 @@ def test_make_realization_no_noise(
new_likelihood.update(params)

assert_allclose(new_likelihood.get_data_vector(), likelihood.get_theory_vector())


def test_get_sacc_indices(
trivial_stats,
sacc_data_for_trivial_stat: sacc.Sacc,
):
likelihood = ConstGaussian(statistics=trivial_stats)
likelihood.read(sacc_data_for_trivial_stat)

idx = likelihood.get_sacc_indices()

assert all(
idx
== np.concatenate(
[stat.statistic.sacc_indices for stat in likelihood.statistics]
)
)


def test_get_sacc_indices_single_stat(
trivial_stats,
sacc_data_for_trivial_stat: sacc.Sacc,
):
likelihood = ConstGaussian(statistics=trivial_stats)
likelihood.read(sacc_data_for_trivial_stat)

idx = likelihood.get_sacc_indices(statistic=likelihood.statistics[0].statistic)

assert all(idx == likelihood.statistics[0].statistic.sacc_indices)
35 changes: 34 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
Tests for the firecrown.utils modle.
"""

from firecrown.utils import upper_triangle_indices
import pytest
import numpy as np

from firecrown.utils import upper_triangle_indices, save_to_sacc


def test_upper_triangle_indices_nonzero():
Expand All @@ -13,3 +16,33 @@ def test_upper_triangle_indices_nonzero():
def test_upper_triangle_indices_zero():
indices = list(upper_triangle_indices(0))
assert not indices


def test_save_to_sacc(trivial_stats, sacc_data_for_trivial_stat):
stat = trivial_stats[0]
stat.read(sacc_data_for_trivial_stat)
idx = np.arange(stat.count)
new_data_vector = 3 * stat.get_data_vector()[idx]

new_sacc = save_to_sacc(
sacc_data=sacc_data_for_trivial_stat,
data_vector=new_data_vector,
indices=idx,
strict=True,
)
assert all(new_sacc.data[i].value == d for i, d in zip(idx, new_data_vector))


def test_save_to_sacc_strict_fail(trivial_stats, sacc_data_for_trivial_stat):
stat = trivial_stats[0]
stat.read(sacc_data_for_trivial_stat)
idx = np.arange(stat.count - 1)
new_data_vector = stat.get_data_vector()[idx]

with pytest.raises(RuntimeError):
_ = save_to_sacc(
sacc_data=sacc_data_for_trivial_stat,
data_vector=new_data_vector,
indices=idx,
strict=True,
)

0 comments on commit 73af686

Please sign in to comment.