From ad9a6e1c5da614fed0c0b583acc2b05ff802d4ad Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Tue, 8 Oct 2024 12:07:27 -0500 Subject: [PATCH] More refactoring of TwoPoint statistic (#460) * Move ells to TwoPointTheory * Move thetas to TwoPointTheory * Move mean_ells to TwoPointTheory * Move ells_for_xi to TwoPointTheory * Move cells into TwoPointTheory --- firecrown/likelihood/two_point.py | 92 ++++++++++++++++++------------- tests/test_pt_systematics.py | 1 - 2 files changed, 54 insertions(+), 39 deletions(-) diff --git a/firecrown/likelihood/two_point.py b/firecrown/likelihood/two_point.py index 39ebace6..e9acea39 100644 --- a/firecrown/likelihood/two_point.py +++ b/firecrown/likelihood/two_point.py @@ -101,6 +101,11 @@ def __init__( self.ell_or_theta_max = ell_or_theta_max self.window: None | npt.NDArray[np.float64] = None self.sacc_tracers: None | TracerNames = None + self.ells: None | npt.NDArray[np.int64] = None + self.thetas: None | npt.NDArray[np.float64] = None + self.mean_ells: None | npt.NDArray[np.float64] = None + self.ells_for_xi: None | npt.NDArray[np.int64] = None + self.cells: dict[TracerNames, npt.NDArray[np.float64]] = {} def set_ccl_kind(self, sacc_data_type): """Set the CCL kind for this statistic.""" @@ -111,6 +116,7 @@ def set_ccl_kind(self, sacc_data_type): raise ValueError(f"The SACC data type {sacc_data_type} is not supported!") +# pylint: disable=too-many-public-methods class TwoPoint(Statistic): """A statistic that represents the correlation between two measurements. @@ -221,6 +227,26 @@ def sacc_tracers(self) -> None | TracerNames: """Backwards compatibility for sacc_tracers.""" return self.theory.sacc_tracers + @property + def ells(self) -> None | npt.NDArray[np.int64]: + """Backwards compatibility for ells.""" + return self.theory.ells + + @property + def thetas(self) -> None | npt.NDArray[np.float64]: + """Backwards compatibility for thetas.""" + return self.theory.thetas + + @property + def ells_for_xi(self) -> None | npt.NDArray[np.int64]: + """Backwards compatibility for ells_for_xi.""" + return self.theory.ells_for_xi + + @property + def cells(self): + """Backwards compatibility for cells.""" + return self.theory.cells + def __init__( self, sacc_data_type: str, @@ -241,17 +267,10 @@ def __init__( sacc_data_type, source0, source1, ell_or_theta_min, ell_or_theta_max ) self.data_vector: None | DataVector - self.ells: None | npt.NDArray[np.int64] - self.thetas: None | npt.NDArray[np.float64] - self.mean_ells: None | npt.NDArray[np.float64] - self.ells_for_xi: None | npt.NDArray[np.int64] - self.cells: dict[TracerNames, npt.NDArray[np.float64]] - self._init_empty_default_attribs() if ell_for_xi is not None: self.theory.ell_for_xi_config.update(ell_for_xi) self.theory.ell_or_theta_config = ell_or_theta - self.theory.set_ccl_kind(sacc_data_type) def _init_empty_default_attribs(self): @@ -262,13 +281,6 @@ def _init_empty_default_attribs(self): self.data_vector = None - self.ells = None - self.thetas = None - self.mean_ells = None - self.ells_for_xi = None - - self.cells = {} - @classmethod def from_metadata_index( cls, @@ -327,15 +339,15 @@ def _from_metadata_single( two_point = cls._from_metadata_single_base( metadata, wl_factory, nc_factory ) - two_point.ells = metadata.ells + two_point.theory.ells = metadata.ells two_point.theory.window = metadata.window case TwoPointReal(): two_point = cls._from_metadata_single_base( metadata, wl_factory, nc_factory ) - two_point.thetas = metadata.thetas + two_point.theory.thetas = metadata.thetas two_point.theory.window = None - two_point.ells_for_xi = log_linear_ells( + two_point.theory.ells_for_xi = log_linear_ells( **two_point.theory.ell_for_xi_config ) case _: @@ -543,8 +555,8 @@ def read_real_space(self, sacc_data: sacc.Sacc): self.theory.ell_or_theta_min, self.theory.ell_or_theta_max, ) - self.ells_for_xi = log_linear_ells(**self.theory.ell_for_xi_config) - self.thetas = thetas + self.theory.ells_for_xi = log_linear_ells(**self.theory.ell_for_xi_config) + self.theory.thetas = thetas self.sacc_indices = sacc_indices self.data_vector = DataVector.create(xis) @@ -601,11 +613,11 @@ def read_harmonic_space(self, sacc_data: sacc.Sacc): self.theory.ell_or_theta_min, self.theory.ell_or_theta_max, ) - self.ells = ells + self.theory.ells = ells if self.theory.ell_or_theta_min is not None: - assert np.min(self.ells) >= self.theory.ell_or_theta_min + assert np.min(self.theory.ells) >= self.theory.ell_or_theta_min if self.theory.ell_or_theta_max is not None: - assert np.max(self.ells) <= self.theory.ell_or_theta_max + assert np.max(self.theory.ells) <= self.theory.ell_or_theta_max self.theory.window = window self.sacc_indices = sacc_indices self.data_vector = DataVector.create(Cells) @@ -637,18 +649,18 @@ def compute_theory_vector_real_space(self, tools: ModelingTools) -> TheoryVector scale1 = self.theory.source1.get_scale() assert self.theory.ccl_kind != "cl" - assert self.thetas is not None - assert self.ells_for_xi is not None + assert self.theory.thetas is not None + assert self.theory.ells_for_xi is not None cells_for_xi = self.compute_cells( - self.ells_for_xi, scale0, scale1, tools, tracers0, tracers1 + self.theory.ells_for_xi, scale0, scale1, tools, tracers0, tracers1 ) theory_vector = pyccl.correlation( tools.get_ccl_cosmology(), - ell=self.ells_for_xi, + ell=self.theory.ells_for_xi, C_ell=cells_for_xi, - theta=self.thetas / 60, + theta=self.theory.thetas / 60, type=self.theory.ccl_kind, ) return TheoryVector.create(theory_vector) @@ -668,15 +680,15 @@ def compute_theory_vector_harmonic_space( scale1 = self.theory.source1.get_scale() assert self.theory.ccl_kind == "cl" - assert self.ells is not None + assert self.theory.ells is not None if self.theory.window is not None: ells_for_interpolation = calculate_ells_for_interpolation( - self.ells[0], self.ells[-1] + self.theory.ells[0], self.theory.ells[-1] ) cells_interpolated = self.compute_cells_interpolated( - self.ells, + self.theory.ells, ells_for_interpolation, scale0, scale1, @@ -691,15 +703,17 @@ def compute_theory_vector_harmonic_space( "lb, l -> b", self.theory.window, cells_interpolated ) # We also compute the mean ell value associated with each bin. - self.mean_ells = np.einsum("lb, l -> b", self.theory.window, self.ells) + self.theory.mean_ells = np.einsum( + "lb, l -> b", self.theory.window, self.theory.ells + ) assert self.data_vector is not None return TheoryVector.create(theory_vector) # If we get here, we are working in harmonic space without a window function. - assert self.ells is not None + assert self.theory.ells is not None theory_vector = self.compute_cells( - self.ells, + self.theory.ells, scale0, scale1, tools, @@ -726,17 +740,17 @@ def compute_cells( tracers1: Sequence[Tracer], ) -> npt.NDArray[np.float64]: """Compute the power spectrum for the given ells and tracers.""" - self.cells = {} + self.theory.cells = {} for tracer0 in tracers0: for tracer1 in tracers1: pk_name = f"{tracer0.field}:{tracer1.field}" tn = TracerNames(tracer0.tracer_name, tracer1.tracer_name) - if tn in self.cells: + if tn in self.theory.cells: # Already computed this combination, skipping continue pk = self.calculate_pk(pk_name, tools, tracer0, tracer1) - self.cells[tn] = ( + self.theory.cells[tn] = ( cached_angular_cl( tools.get_ccl_cosmology(), (tracer0.ccl_tracer, tracer1.ccl_tracer), @@ -746,8 +760,10 @@ def compute_cells( * scale0 * scale1 ) - self.cells[TRACER_NAMES_TOTAL] = np.array(sum(self.cells.values())) - theory_vector = self.cells[TRACER_NAMES_TOTAL] + self.theory.cells[TRACER_NAMES_TOTAL] = np.array( + sum(self.theory.cells.values()) + ) + theory_vector = self.theory.cells[TRACER_NAMES_TOTAL] return theory_vector def compute_cells_interpolated( diff --git a/tests/test_pt_systematics.py b/tests/test_pt_systematics.py index 8857c84f..11a4d1fe 100644 --- a/tests/test_pt_systematics.py +++ b/tests/test_pt_systematics.py @@ -161,7 +161,6 @@ def test_pt_systematics(weak_lensing_source, number_counts_source, sacc_data): s1 = likelihood.statistics[1].statistic # del weak_lensing_source.cosmo_hash - s1.cells = {} s1.compute_theory_vector(modeling_tools) assert isinstance(s1, TwoPoint) ells = s1.ells_for_xi