Skip to content

Commit

Permalink
More refactoring of TwoPoint statistic (#460)
Browse files Browse the repository at this point in the history
* Move ells to TwoPointTheory
* Move thetas to TwoPointTheory
* Move mean_ells to TwoPointTheory
* Move ells_for_xi to TwoPointTheory
* Move cells into TwoPointTheory
  • Loading branch information
marcpaterno authored Oct 8, 2024
1 parent a26cc7a commit ad9a6e1
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 39 deletions.
92 changes: 54 additions & 38 deletions firecrown/likelihood/two_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ def __init__(
self.ell_or_theta_max = ell_or_theta_max
self.window: None | npt.NDArray[np.float64] = None
self.sacc_tracers: None | TracerNames = None
self.ells: None | npt.NDArray[np.int64] = None
self.thetas: None | npt.NDArray[np.float64] = None
self.mean_ells: None | npt.NDArray[np.float64] = None
self.ells_for_xi: None | npt.NDArray[np.int64] = None
self.cells: dict[TracerNames, npt.NDArray[np.float64]] = {}

def set_ccl_kind(self, sacc_data_type):
"""Set the CCL kind for this statistic."""
Expand All @@ -111,6 +116,7 @@ def set_ccl_kind(self, sacc_data_type):
raise ValueError(f"The SACC data type {sacc_data_type} is not supported!")


# pylint: disable=too-many-public-methods
class TwoPoint(Statistic):
"""A statistic that represents the correlation between two measurements.
Expand Down Expand Up @@ -221,6 +227,26 @@ def sacc_tracers(self) -> None | TracerNames:
"""Backwards compatibility for sacc_tracers."""
return self.theory.sacc_tracers

@property
def ells(self) -> None | npt.NDArray[np.int64]:
"""Backwards compatibility for ells."""
return self.theory.ells

@property
def thetas(self) -> None | npt.NDArray[np.float64]:
"""Backwards compatibility for thetas."""
return self.theory.thetas

@property
def ells_for_xi(self) -> None | npt.NDArray[np.int64]:
"""Backwards compatibility for ells_for_xi."""
return self.theory.ells_for_xi

@property
def cells(self):
"""Backwards compatibility for cells."""
return self.theory.cells

def __init__(
self,
sacc_data_type: str,
Expand All @@ -241,17 +267,10 @@ def __init__(
sacc_data_type, source0, source1, ell_or_theta_min, ell_or_theta_max
)
self.data_vector: None | DataVector
self.ells: None | npt.NDArray[np.int64]
self.thetas: None | npt.NDArray[np.float64]
self.mean_ells: None | npt.NDArray[np.float64]
self.ells_for_xi: None | npt.NDArray[np.int64]
self.cells: dict[TracerNames, npt.NDArray[np.float64]]

self._init_empty_default_attribs()
if ell_for_xi is not None:
self.theory.ell_for_xi_config.update(ell_for_xi)
self.theory.ell_or_theta_config = ell_or_theta

self.theory.set_ccl_kind(sacc_data_type)

def _init_empty_default_attribs(self):
Expand All @@ -262,13 +281,6 @@ def _init_empty_default_attribs(self):

self.data_vector = None

self.ells = None
self.thetas = None
self.mean_ells = None
self.ells_for_xi = None

self.cells = {}

@classmethod
def from_metadata_index(
cls,
Expand Down Expand Up @@ -327,15 +339,15 @@ def _from_metadata_single(
two_point = cls._from_metadata_single_base(
metadata, wl_factory, nc_factory
)
two_point.ells = metadata.ells
two_point.theory.ells = metadata.ells
two_point.theory.window = metadata.window
case TwoPointReal():
two_point = cls._from_metadata_single_base(
metadata, wl_factory, nc_factory
)
two_point.thetas = metadata.thetas
two_point.theory.thetas = metadata.thetas
two_point.theory.window = None
two_point.ells_for_xi = log_linear_ells(
two_point.theory.ells_for_xi = log_linear_ells(
**two_point.theory.ell_for_xi_config
)
case _:
Expand Down Expand Up @@ -543,8 +555,8 @@ def read_real_space(self, sacc_data: sacc.Sacc):
self.theory.ell_or_theta_min,
self.theory.ell_or_theta_max,
)
self.ells_for_xi = log_linear_ells(**self.theory.ell_for_xi_config)
self.thetas = thetas
self.theory.ells_for_xi = log_linear_ells(**self.theory.ell_for_xi_config)
self.theory.thetas = thetas
self.sacc_indices = sacc_indices
self.data_vector = DataVector.create(xis)

Expand Down Expand Up @@ -601,11 +613,11 @@ def read_harmonic_space(self, sacc_data: sacc.Sacc):
self.theory.ell_or_theta_min,
self.theory.ell_or_theta_max,
)
self.ells = ells
self.theory.ells = ells
if self.theory.ell_or_theta_min is not None:
assert np.min(self.ells) >= self.theory.ell_or_theta_min
assert np.min(self.theory.ells) >= self.theory.ell_or_theta_min
if self.theory.ell_or_theta_max is not None:
assert np.max(self.ells) <= self.theory.ell_or_theta_max
assert np.max(self.theory.ells) <= self.theory.ell_or_theta_max
self.theory.window = window
self.sacc_indices = sacc_indices
self.data_vector = DataVector.create(Cells)
Expand Down Expand Up @@ -637,18 +649,18 @@ def compute_theory_vector_real_space(self, tools: ModelingTools) -> TheoryVector
scale1 = self.theory.source1.get_scale()

assert self.theory.ccl_kind != "cl"
assert self.thetas is not None
assert self.ells_for_xi is not None
assert self.theory.thetas is not None
assert self.theory.ells_for_xi is not None

cells_for_xi = self.compute_cells(
self.ells_for_xi, scale0, scale1, tools, tracers0, tracers1
self.theory.ells_for_xi, scale0, scale1, tools, tracers0, tracers1
)

theory_vector = pyccl.correlation(
tools.get_ccl_cosmology(),
ell=self.ells_for_xi,
ell=self.theory.ells_for_xi,
C_ell=cells_for_xi,
theta=self.thetas / 60,
theta=self.theory.thetas / 60,
type=self.theory.ccl_kind,
)
return TheoryVector.create(theory_vector)
Expand All @@ -668,15 +680,15 @@ def compute_theory_vector_harmonic_space(
scale1 = self.theory.source1.get_scale()

assert self.theory.ccl_kind == "cl"
assert self.ells is not None
assert self.theory.ells is not None

if self.theory.window is not None:
ells_for_interpolation = calculate_ells_for_interpolation(
self.ells[0], self.ells[-1]
self.theory.ells[0], self.theory.ells[-1]
)

cells_interpolated = self.compute_cells_interpolated(
self.ells,
self.theory.ells,
ells_for_interpolation,
scale0,
scale1,
Expand All @@ -691,15 +703,17 @@ def compute_theory_vector_harmonic_space(
"lb, l -> b", self.theory.window, cells_interpolated
)
# We also compute the mean ell value associated with each bin.
self.mean_ells = np.einsum("lb, l -> b", self.theory.window, self.ells)
self.theory.mean_ells = np.einsum(
"lb, l -> b", self.theory.window, self.theory.ells
)

assert self.data_vector is not None
return TheoryVector.create(theory_vector)

# If we get here, we are working in harmonic space without a window function.
assert self.ells is not None
assert self.theory.ells is not None
theory_vector = self.compute_cells(
self.ells,
self.theory.ells,
scale0,
scale1,
tools,
Expand All @@ -726,17 +740,17 @@ def compute_cells(
tracers1: Sequence[Tracer],
) -> npt.NDArray[np.float64]:
"""Compute the power spectrum for the given ells and tracers."""
self.cells = {}
self.theory.cells = {}
for tracer0 in tracers0:
for tracer1 in tracers1:
pk_name = f"{tracer0.field}:{tracer1.field}"
tn = TracerNames(tracer0.tracer_name, tracer1.tracer_name)
if tn in self.cells:
if tn in self.theory.cells:
# Already computed this combination, skipping
continue
pk = self.calculate_pk(pk_name, tools, tracer0, tracer1)

self.cells[tn] = (
self.theory.cells[tn] = (
cached_angular_cl(
tools.get_ccl_cosmology(),
(tracer0.ccl_tracer, tracer1.ccl_tracer),
Expand All @@ -746,8 +760,10 @@ def compute_cells(
* scale0
* scale1
)
self.cells[TRACER_NAMES_TOTAL] = np.array(sum(self.cells.values()))
theory_vector = self.cells[TRACER_NAMES_TOTAL]
self.theory.cells[TRACER_NAMES_TOTAL] = np.array(
sum(self.theory.cells.values())
)
theory_vector = self.theory.cells[TRACER_NAMES_TOTAL]
return theory_vector

def compute_cells_interpolated(
Expand Down
1 change: 0 additions & 1 deletion tests/test_pt_systematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ def test_pt_systematics(weak_lensing_source, number_counts_source, sacc_data):

s1 = likelihood.statistics[1].statistic
# del weak_lensing_source.cosmo_hash
s1.cells = {}
s1.compute_theory_vector(modeling_tools)
assert isinstance(s1, TwoPoint)
ells = s1.ells_for_xi
Expand Down

0 comments on commit ad9a6e1

Please sign in to comment.