Skip to content

Commit

Permalink
Testing new constructors.
Browse files Browse the repository at this point in the history
Comparing new and old constructors.
  • Loading branch information
vitenti committed Jul 6, 2024
1 parent 68f38f4 commit 6fded03
Show file tree
Hide file tree
Showing 7 changed files with 366 additions and 27 deletions.
71 changes: 65 additions & 6 deletions firecrown/likelihood/two_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
TwoPointXiTheta,
Window,
extract_window_function,
check_two_point_consistence_harmonic,
check_two_point_consistence_real,
)
from firecrown.modeling_tools import ModelingTools
from firecrown.updatable import UpdatableCollection
Expand Down Expand Up @@ -429,9 +431,7 @@ def _set_ccl_kind(self, sacc_data_type):
if self.sacc_data_type in SACC_DATA_TYPE_TO_CCL_KIND:
self.ccl_kind = SACC_DATA_TYPE_TO_CCL_KIND[self.sacc_data_type]
else:
raise ValueError(
f"The SACC data type {sacc_data_type}'%s' is not " f"supported!"
)
raise ValueError(f"The SACC data type {sacc_data_type} is not supported!")

Check warning on line 434 in firecrown/likelihood/two_point.py

View check run for this annotation

Codecov / codecov/patch

firecrown/likelihood/two_point.py#L434

Added line #L434 was not covered by tests

@classmethod
def _from_metadata(
Expand All @@ -457,6 +457,7 @@ def _from_metadata(

def _init_from_cells(self, metadata: TwoPointCells):
"""Initialize the TwoPoint statistic from a TwoPointCells metadata object."""
self.sacc_tracers = metadata.XY.get_tracer_names()
self.ells = metadata.ells
self.window = None
if metadata.Cell is not None:
Expand All @@ -465,27 +466,39 @@ def _init_from_cells(self, metadata: TwoPointCells):

def _init_from_cwindow(self, metadata: TwoPointCWindow):
"""Initialize the TwoPoint statistic from a TwoPointCWindow metadata object."""
self.sacc_tracers = metadata.XY.get_tracer_names()
self.window = metadata.window
if self.window.ells_for_interpolation is None:
self.window.ells_for_interpolation = calculate_ells_for_interpolation(
self.window
)
if metadata.Cell is not None:
self.sacc_indices = metadata.Cell.indices
self.data_vector = DataVector.create(metadata.Cell.data)

def _init_from_xi_theta(self, metadata: TwoPointXiTheta):
"""Initialize the TwoPoint statistic from a TwoPointXiTheta metadata object."""
self.sacc_tracers = metadata.XY.get_tracer_names()
self.thetas = metadata.thetas
self.window = None
self.ells_for_xi = _ell_for_xi(**self.ell_for_xi_config)
if metadata.xis is not None:
self.sacc_indices = metadata.xis.indices
self.data_vector = DataVector.create(metadata.xis.data)

@classmethod
def from_metadata_cells(
def _from_metadata_any(
cls,
metadata: Sequence[TwoPointCells | TwoPointCWindow | TwoPointXiTheta],
wl_factory: WeakLensingFactory | None = None,
nc_factory: NumberCountsFactory | None = None,
) -> UpdatableCollection[TwoPoint]:
"""Create a TwoPoint statistic from a TwoPointCells metadata object."""
"""Create an UpdatableCollection of TwoPoint statistics.
This constructor creates an UpdatableCollection of TwoPoint statistics from a
list of TwoPointCells, TwoPointCWindow or TwoPointXiTheta metadata objects.
The metadata objects are used to initialize the TwoPoint statistics.
"""
two_point_list = [
cls._from_metadata(
sacc_data_type=cell.get_sacc_name(),
Expand All @@ -502,6 +515,52 @@ def from_metadata_cells(

return UpdatableCollection(two_point_list)

@classmethod
def from_metadata_harmonic(
cls,
metadata: Sequence[TwoPointCells | TwoPointCWindow],
wl_factory: WeakLensingFactory | None = None,
nc_factory: NumberCountsFactory | None = None,
check_consistence: bool = False,
) -> UpdatableCollection[TwoPoint]:
"""Create an UpdatableCollection of harmonic space TwoPoint statistics.
This constructor creates an UpdatableCollection of TwoPoint statistics from a
list of TwoPointCells or TwoPointCWindow metadata objects. The metadata objects
are used to initialize the TwoPoint statistics.
:param metadata: The metadata objects to initialize the TwoPoint statistics.
:param wl_factory: The weak lensing factory to use.
:param nc_factory: The number counts factory to use.
:param check_consistence: Whether to check the consistence of the metadata.
"""
if check_consistence:
check_two_point_consistence_harmonic(metadata)

Check warning on line 538 in firecrown/likelihood/two_point.py

View check run for this annotation

Codecov / codecov/patch

firecrown/likelihood/two_point.py#L538

Added line #L538 was not covered by tests
return cls._from_metadata_any(metadata, wl_factory, nc_factory)

@classmethod
def from_metadata_real(
cls,
metadata: Sequence[TwoPointXiTheta],
wl_factory: WeakLensingFactory | None = None,
nc_factory: NumberCountsFactory | None = None,
check_consistence: bool = False,
) -> UpdatableCollection[TwoPoint]:
"""Create an UpdatableCollection of real space TwoPoint statistics.
This constructor creates an UpdatableCollection of TwoPoint statistics from a
list of TwoPointXiTheta metadata objects. The metadata objects are used to
initialize the TwoPoint statistics.
:param metadata: The metadata objects to initialize the TwoPoint statistics.
:param wl_factory: The weak lensing factory to use.
:param nc_factory: The number counts factory to use.
:param check_consistence: Whether to check the consistence of the metadata.
"""
if check_consistence:
check_two_point_consistence_real(metadata)

Check warning on line 561 in firecrown/likelihood/two_point.py

View check run for this annotation

Codecov / codecov/patch

firecrown/likelihood/two_point.py#L561

Added line #L561 was not covered by tests
return cls._from_metadata_any(metadata, wl_factory, nc_factory)

def read_ell_cells(
self, sacc_data_type: str, sacc_data: sacc.Sacc, tracers: TracerNames
) -> (
Expand Down Expand Up @@ -709,7 +768,7 @@ def compute_theory_vector_harmonic_space(
scale1 = self.source1.get_scale()

assert self.ccl_kind == "cl"
assert self.ells is not None
assert (self.ells is not None) or (self.window is not None)

if self.window is not None:
# If a window function is provided, we need to compute the Cl's
Expand Down
4 changes: 4 additions & 0 deletions firecrown/metadata/two_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def __str__(self) -> str:
"""Return a string representation of the TwoPointXY object."""
return f"({self.x.bin_name}, {self.y.bin_name})"

Check warning on line 134 in firecrown/metadata/two_point.py

View check run for this annotation

Codecov / codecov/patch

firecrown/metadata/two_point.py#L134

Added line #L134 was not covered by tests

def get_tracer_names(self) -> TracerNames:
"""Return the TracerNames object for the TwoPointXY object."""
return TracerNames(self.x.bin_name, self.y.bin_name)


@dataclass(frozen=True, kw_only=True)
class TwoPointMeasurement(YAMLSerializable):
Expand Down
24 changes: 23 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from firecrown.connector.mapping import MappingCosmoSIS, mapping_builder
from firecrown.modeling_tools import ModelingTools
from firecrown.metadata.two_point import TracerNames
import firecrown.likelihood.weak_lensing as wl
import firecrown.likelihood.number_counts as nc


def pytest_addoption(parser):
Expand Down Expand Up @@ -148,6 +150,8 @@ def fixture_tools_with_vanilla_cosmology():
result.update(ParamsMap())
result.prepare(pyccl.CosmologyVanillaLCDM())

return result


@pytest.fixture(name="cluster_sacc_data")
def fixture_cluster_sacc_data() -> sacc.Sacc:
Expand Down Expand Up @@ -437,7 +441,10 @@ def fixture_sacc_galaxy_cwindows():

tracers: dict[str, tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]] = {}
tracer_pairs: dict[
TracerNames, tuple[str, npt.NDArray[np.int64], npt.NDArray[np.float64]]
TracerNames,
tuple[
str, npt.NDArray[np.int64], npt.NDArray[np.float64], sacc.BandpowerWindow
],
] = {}

for i, mn in enumerate(src_bins_centers):
Expand Down Expand Up @@ -471,6 +478,7 @@ def fixture_sacc_galaxy_cwindows():
"galaxy_shear_cl_ee",
ells,
Cells,
window,
)

for i, j in upper_triangle_indices(len(lens_bins_centers)):
Expand All @@ -493,6 +501,7 @@ def fixture_sacc_galaxy_cwindows():
"galaxy_density_cl",
ells,
Cells,
window,
)

for i, j in product(range(len(src_bins_centers)), range(len(lens_bins_centers))):
Expand All @@ -515,6 +524,7 @@ def fixture_sacc_galaxy_cwindows():
"galaxy_shearDensity_cl_e",
ells,
Cells,
window,
)

sacc_data.add_covariance(np.identity(len(sacc_data)) * 0.01)
Expand Down Expand Up @@ -654,3 +664,15 @@ def fixture_sacc_galaxy_cells_src0_src0_no_window():
sacc_data.add_covariance(cov)

return sacc_data, z, dndz


@pytest.fixture(name="wl_factory")
def make_wl_factory():
"""Generate a WeakLensingFactory object."""
return wl.WeakLensingFactory(per_bin_systematics=[], global_systematics=[])


@pytest.fixture(name="nc_factory")
def make_nc_factory():
"""Generate a NumberCountsFactory object."""
return nc.NumberCountsFactory(per_bin_systematics=[], global_systematics=[])
65 changes: 49 additions & 16 deletions tests/metadata/test_metadata_two_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
Window,
)
from firecrown.likelihood.source import SourceGalaxy
import firecrown.likelihood.weak_lensing as wl
import firecrown.likelihood.number_counts as nc
from firecrown.likelihood.two_point import TwoPoint


Expand Down Expand Up @@ -121,18 +119,6 @@ def make_two_point_cwindow_1(
return two_point


@pytest.fixture(name="wl_factory")
def make_wl_factory():
"""Generate a WeakLensingFactory object."""
return wl.WeakLensingFactory(per_bin_systematics=[], global_systematics=[])


@pytest.fixture(name="nc_factory")
def make_nc_factory():
"""Generate a NumberCountsFactory object."""
return nc.NumberCountsFactory(per_bin_systematics=[], global_systematics=[])


def test_order_enums():
assert compare_enums(CMB.CONVERGENCE, Clusters.COUNTS) < 0
assert compare_enums(Clusters.COUNTS, CMB.CONVERGENCE) > 0
Expand Down Expand Up @@ -746,7 +732,7 @@ def test_two_point_from_metadata_cells(
ells = np.array(np.linspace(0, 100, 100), dtype=np.int64)
xy = TwoPointXY(x=harmonic_bin_1, y=harmonic_bin_2)
cells = TwoPointCells(ells=ells, XY=xy)
two_point = TwoPoint.from_metadata_cells([cells], wl_factory, nc_factory).pop()
two_point = TwoPoint.from_metadata_harmonic([cells], wl_factory, nc_factory).pop()

assert two_point is not None
assert isinstance(two_point, TwoPoint)
Expand All @@ -762,6 +748,53 @@ def test_two_point_from_metadata_cells(
assert_array_equal(two_point.source1.tracer_args.dndz, harmonic_bin_2.dndz)


def test_two_point_from_metadata_cwindow(two_point_cwindow_1, wl_factory, nc_factory):
two_point = TwoPoint.from_metadata_harmonic(
[two_point_cwindow_1], wl_factory, nc_factory
).pop()

assert two_point is not None
assert isinstance(two_point, TwoPoint)
assert two_point.sacc_data_type == two_point_cwindow_1.get_sacc_name()

assert isinstance(two_point.source0, SourceGalaxy)
assert isinstance(two_point.source1, SourceGalaxy)

assert_array_equal(two_point.source0.tracer_args.z, two_point_cwindow_1.XY.x.z)
assert_array_equal(two_point.source1.tracer_args.z, two_point_cwindow_1.XY.y.z)

assert_array_equal(
two_point.source0.tracer_args.dndz, two_point_cwindow_1.XY.x.dndz
)
assert_array_equal(
two_point.source1.tracer_args.dndz, two_point_cwindow_1.XY.y.dndz
)


def test_two_point_from_metadata_xi_theta(
real_bin_1, real_bin_2, wl_factory, nc_factory
):
theta = np.array(np.linspace(0, 100, 100))
xy = TwoPointXY(x=real_bin_1, y=real_bin_2)
xi_theta = TwoPointXiTheta(XY=xy, thetas=theta)
if xi_theta.get_sacc_name() == "galaxy_shear_xi_tt":
return
two_point = TwoPoint.from_metadata_real([xi_theta], wl_factory, nc_factory).pop()

assert two_point is not None
assert isinstance(two_point, TwoPoint)
assert two_point.sacc_data_type == xi_theta.get_sacc_name()

assert isinstance(two_point.source0, SourceGalaxy)
assert isinstance(two_point.source1, SourceGalaxy)

assert_array_equal(two_point.source0.tracer_args.z, real_bin_1.z)
assert_array_equal(two_point.source1.tracer_args.z, real_bin_2.z)

assert_array_equal(two_point.source0.tracer_args.dndz, real_bin_1.dndz)
assert_array_equal(two_point.source1.tracer_args.dndz, real_bin_2.dndz)


def test_two_point_from_metadata_cells_unsupported_type(wl_factory, nc_factory):
ells = np.array(np.linspace(0, 100, 100), dtype=np.int64)
x = InferredGalaxyZDist(
Expand All @@ -782,4 +815,4 @@ def test_two_point_from_metadata_cells_unsupported_type(wl_factory, nc_factory):
ValueError,
match="Measurement .* not supported!",
):
TwoPoint.from_metadata_cells([cells], wl_factory, nc_factory)
TwoPoint.from_metadata_harmonic([cells], wl_factory, nc_factory)
Loading

0 comments on commit 6fded03

Please sign in to comment.