Skip to content

Commit

Permalink
Restructuring metadata objects (#448)
Browse files Browse the repository at this point in the history
* First move of files, and adjustments
* Removal of Window and related classes
* Merge of TwoPointCell and TwoPointCWindow into TwoPointHarmonic.
* Updated tests for the new object structure.
* Finished inversion of data and metadata.
* Adding tests for uncovered lines.
* Updating tutorials to use new structure.
* Testing window function application.
* Removing left-over arrays from TwoPoint*Index dataclasses.
* Putting ids on parametrized tests.

---------

Co-authored-by: Marc Paterno <paterno@fnal.gov>
  • Loading branch information
vitenti and marcpaterno authored Sep 10, 2024
1 parent 5b28284 commit 4eef092
Show file tree
Hide file tree
Showing 25 changed files with 1,464 additions and 1,289 deletions.
200 changes: 200 additions & 0 deletions firecrown/data_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
"""This module deals with two-point data functions.
It contains functions to manipulate two-point data objects.
"""

import hashlib
from typing import Sequence

import sacc

from firecrown.metadata_types import (
TwoPointHarmonic,
TwoPointReal,
)
from firecrown.metadata_functions import (
extract_all_tracers_inferred_galaxy_zdists,
extract_window_function,
extract_all_harmonic_metadata_indices,
extract_all_real_metadata_indices,
make_two_point_xy,
)
from firecrown.data_types import TwoPointMeasurement


def extract_all_harmonic_data(
    sacc_data: sacc.Sacc,
    allowed_data_type: None | list[str] = None,
    include_maybe_types=False,
) -> list[TwoPointMeasurement]:
    """Collect every harmonic-space two-point measurement found in a SACC object.

    One TwoPointMeasurement is produced for each entry returned by
    extract_all_harmonic_metadata_indices. Each measurement carries the data
    and covariance indices returned by ``sacc_data.get_ell_cl`` together with
    a TwoPointHarmonic metadata object.

    :param sacc_data: the SACC object to read from.
    :param allowed_data_type: forwarded unchanged to
        extract_all_harmonic_metadata_indices.
    :param include_maybe_types: forwarded unchanged to
        extract_all_tracers_inferred_galaxy_zdists.
    :return: the list of extracted measurements.
    :raises ValueError: if the SACC object has no (dense) covariance matrix.
    """
    # Index the inferred redshift distributions by bin name for XY creation.
    zdists_by_bin_name = {}
    for zdist in extract_all_tracers_inferred_galaxy_zdists(
        sacc_data, include_maybe_types=include_maybe_types
    ):
        zdists_by_bin_name[zdist.bin_name] = zdist

    covariance = sacc_data.covariance
    dense_covariance = None if covariance is None else covariance.dense
    if dense_covariance is None:
        raise ValueError("The SACC object does not have a covariance matrix.")
    # The SHA-256 digest of the dense covariance serves as the covariance name.
    covariance_name = hashlib.sha256(dense_covariance).hexdigest()

    measurements: list[TwoPointMeasurement] = []
    for harmonic_index in extract_all_harmonic_metadata_indices(
        sacc_data, allowed_data_type
    ):
        tracer_names = harmonic_index["tracer_names"]
        tracer1, tracer2 = tracer_names
        data_type = harmonic_index["data_type"]

        ells, c_ells, cov_indices = sacc_data.get_ell_cl(
            data_type=data_type,
            tracer1=tracer1,
            tracer2=tracer2,
            return_cov=False,
            return_ind=True,
        )

        # When the window function supplies its own ells they take precedence
        # over the ones returned by get_ell_cl.
        window_ells, window_weights = extract_window_function(sacc_data, cov_indices)
        if window_ells is not None:
            ells = window_ells

        harmonic_metadata = TwoPointHarmonic(
            XY=make_two_point_xy(zdists_by_bin_name, tracer_names, data_type),
            window=window_weights,
            ells=ells,
        )
        measurements.append(
            TwoPointMeasurement(
                data=c_ells,
                indices=cov_indices,
                covariance_name=covariance_name,
                metadata=harmonic_metadata,
            )
        )

    return measurements


# Extracting the real-space two-point function metadata and data from a sacc file


def extract_all_real_data(
    sacc_data: sacc.Sacc,
    allowed_data_type: None | list[str] = None,
    include_maybe_types: bool = False,
) -> list[TwoPointMeasurement]:
    """Extract the real-space two-point function metadata and data from a sacc file.

    One TwoPointMeasurement is produced for each entry returned by
    extract_all_real_metadata_indices. Each measurement carries the data and
    covariance indices returned by ``sacc_data.get_theta_xi`` together with a
    TwoPointReal metadata object.

    :param sacc_data: the SACC object to read from.
    :param allowed_data_type: forwarded unchanged to
        extract_all_real_metadata_indices.
    :param include_maybe_types: forwarded unchanged to
        extract_all_tracers_inferred_galaxy_zdists.
    :return: the list of extracted measurements.
    :raises ValueError: if the SACC object has no (dense) covariance matrix.
    """
    inferred_galaxy_zdists_dict = {
        igz.bin_name: igz
        for igz in extract_all_tracers_inferred_galaxy_zdists(
            sacc_data, include_maybe_types=include_maybe_types
        )
    }

    # Fail with a clear message instead of an AttributeError/TypeError when
    # the covariance is missing; the harmonic-space extractor does the same.
    if sacc_data.covariance is None or sacc_data.covariance.dense is None:
        raise ValueError("The SACC object does not have a covariance matrix.")
    # The SHA-256 digest of the dense covariance serves as the covariance name.
    cov_hash = hashlib.sha256(sacc_data.covariance.dense).hexdigest()

    result: list[TwoPointMeasurement] = []
    for real_index in extract_all_real_metadata_indices(sacc_data, allowed_data_type):
        t1, t2 = real_index["tracer_names"]
        dt = real_index["data_type"]

        thetas, Xis, indices = sacc_data.get_theta_xi(
            data_type=dt, tracer1=t1, tracer2=t2, return_cov=False, return_ind=True
        )

        result.append(
            TwoPointMeasurement(
                data=Xis,
                indices=indices,
                covariance_name=cov_hash,
                metadata=TwoPointReal(
                    XY=make_two_point_xy(
                        inferred_galaxy_zdists_dict, real_index["tracer_names"], dt
                    ),
                    thetas=thetas,
                ),
            )
        )

    return result


def check_two_point_consistence_harmonic(
    two_point_harmonics: Sequence[TwoPointMeasurement],
) -> None:
    """Validate a collection of harmonic-space two-point measurements.

    Each measurement must report harmonic metadata, all measurements must
    share one covariance name, the indices of each measurement must be unique,
    and no index may be used by more than one measurement.

    :param two_point_harmonics: the measurements to validate.
    :raises ValueError: at the first measurement that breaks one of the rules.
    """
    expected_cov_name: None | str = None
    seen_indices: set[int] = set()
    earlier_index_sets: list[set[int]] = []

    for measurement in two_point_harmonics:
        if not measurement.is_harmonic():
            raise ValueError(
                f"The metadata of the TwoPointMeasurement {measurement} is not "
                f"a measurement of TwoPointHarmonic."
            )

        # The first measurement fixes the covariance name for all the others.
        if expected_cov_name is None:
            expected_cov_name = measurement.covariance_name
        elif expected_cov_name != measurement.covariance_name:
            raise ValueError(
                f"The TwoPointHarmonic {measurement} has a different covariance "
                f"name {measurement.covariance_name} than the previous "
                f"TwoPointHarmonic {expected_cov_name}."
            )

        current_indices = set(measurement.indices)
        # A set smaller than its source means some index was repeated.
        if len(current_indices) != len(measurement.indices):
            raise ValueError(
                f"The indices of the TwoPointHarmonic {measurement} are not unique."
            )

        # Cheap test against the union first; only on a hit look for the
        # earliest measurement responsible so that it can be reported.
        if not seen_indices.isdisjoint(current_indices):
            for position, earlier_indices in enumerate(earlier_index_sets):
                if earlier_indices & current_indices:
                    raise ValueError(
                        f"The indices of the TwoPointHarmonic "
                        f"{two_point_harmonics[position]} and {measurement} overlap."
                    )

        earlier_index_sets.append(current_indices)
        seen_indices |= current_indices


def check_two_point_consistence_real(
    two_point_reals: Sequence[TwoPointMeasurement],
) -> None:
    """Validate a collection of real-space two-point measurements.

    Each measurement must report real-space metadata, all measurements must
    share one covariance name, the indices of each measurement must be unique,
    and no index may be used by more than one measurement.

    :param two_point_reals: the measurements to validate.
    :raises ValueError: at the first measurement that breaks one of the rules.
    """
    expected_cov_name: None | str = None
    seen_indices: set[int] = set()
    earlier_index_sets: list[set[int]] = []

    for measurement in two_point_reals:
        if not measurement.is_real():
            raise ValueError(
                f"The metadata of the TwoPointMeasurement {measurement} is not "
                f"a measurement of TwoPointReal."
            )

        # The first measurement fixes the covariance name for all the others.
        if expected_cov_name is None:
            expected_cov_name = measurement.covariance_name
        elif expected_cov_name != measurement.covariance_name:
            raise ValueError(
                f"The TwoPointReal {measurement} has a different covariance "
                f"name {measurement.covariance_name} than the previous "
                f"TwoPointReal {expected_cov_name}."
            )

        current_indices = set(measurement.indices)
        # A set smaller than its source means some index was repeated.
        if len(current_indices) != len(measurement.indices):
            raise ValueError(
                f"The indices of the TwoPointReal {measurement} are not unique."
            )

        # Cheap test against the union first; only on a hit look for the
        # earliest measurement responsible so that it can be reported.
        if not seen_indices.isdisjoint(current_indices):
            for position, earlier_indices in enumerate(earlier_index_sets):
                if earlier_indices & current_indices:
                    raise ValueError(
                        f"The indices of the TwoPointReal {two_point_reals[position]} "
                        f"and {measurement} overlap."
                    )

        earlier_index_sets.append(current_indices)
        seen_indices |= current_indices
62 changes: 62 additions & 0 deletions firecrown/data_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""This module deals with data types.
This module contains data types definitions.
"""

from dataclasses import dataclass

import numpy as np
import numpy.typing as npt

from firecrown.utils import YAMLSerializable
from firecrown.metadata_types import TwoPointReal, TwoPointHarmonic


@dataclass(frozen=True, kw_only=True)
class TwoPointMeasurement(YAMLSerializable):
    """Class defining the data for a two-point measurement.

    The class used to store the data for a two-point function measured on a sphere.
    This includes the measured two-point function, their indices in the covariance
    matrix and the name of the covariance matrix. The corresponding metadata is also
    stored.

    Instances are validated on construction (see ``__post_init__``) and compare
    equal when data, indices, covariance name and metadata all match.
    """

    # Measured values of the two-point function; must be a 1D array.
    data: npt.NDArray[np.float64]
    # Indices of the measured values in the covariance matrix; must have the
    # same shape as ``data``.
    indices: npt.NDArray[np.int64]
    # Name of the covariance matrix the indices refer to.
    covariance_name: str
    # Metadata describing the measurement (real or harmonic space).
    metadata: TwoPointReal | TwoPointHarmonic

    def __post_init__(self) -> None:
        """Make sure the data, indices and metadata are mutually consistent.

        :raises ValueError: if the data is not 1D, if data and indices differ
            in shape, if the metadata is neither a TwoPointReal nor a
            TwoPointHarmonic, or if the number of data points differs from
            ``metadata.n_observations()``.
        """
        if len(self.data.shape) != 1:
            raise ValueError("Data should be a 1D array.")

        if self.data.shape != self.indices.shape:
            raise ValueError("Data and indices should have the same shape.")

        if not isinstance(self.metadata, (TwoPointReal, TwoPointHarmonic)):
            raise ValueError(
                "Metadata should be an instance of TwoPointReal or TwoPointHarmonic."
            )

        if len(self.data) != self.metadata.n_observations():
            raise ValueError("Data and metadata should have the same length.")

    def __eq__(self, other) -> bool:
        """Equality test for TwoPointMeasurement objects.

        Arrays are compared element-wise with ``np.array_equal``. Comparing
        with an object that is not a TwoPointMeasurement returns
        ``NotImplemented`` so that Python falls back to its default handling
        (``==`` evaluates to False) instead of raising AttributeError.
        """
        if not isinstance(other, TwoPointMeasurement):
            return NotImplemented
        return (
            np.array_equal(self.data, other.data)
            and np.array_equal(self.indices, other.indices)
            and self.covariance_name == other.covariance_name
            and self.metadata == other.metadata
        )

    def is_real(self) -> bool:
        """Check if the metadata is real."""
        return isinstance(self.metadata, TwoPointReal)

    def is_harmonic(self) -> bool:
        """Check if the metadata is harmonic."""
        return isinstance(self.metadata, TwoPointHarmonic)
6 changes: 2 additions & 4 deletions firecrown/generators/inferred_galaxy_zdist.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,15 @@

from numcosmo_py import Ncm

from firecrown.metadata.two_point_types import (
from firecrown.metadata_types import (
InferredGalaxyZDist,
ALL_MEASUREMENT_TYPES,
)
from firecrown.metadata.two_point_types import (
make_measurements_dict,
Measurement,
Galaxies,
CMB,
Clusters,
)
from firecrown.metadata_functions import Measurement


BinsType = TypedDict("BinsType", {"edges": npt.NDArray, "sigma_z": float})
Expand Down
2 changes: 1 addition & 1 deletion firecrown/likelihood/gaussfamily.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def _set_covariance(self, covariance: npt.NDArray[np.float64]) -> None:
data_vector = np.concatenate(data_vector_list)
cov = np.zeros((len(indices), len(indices)))

largest_index = np.max(indices)
largest_index = int(np.max(indices))

if not (
covariance.ndim == 2
Expand Down
2 changes: 1 addition & 1 deletion firecrown/likelihood/number_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
SourceGalaxySystematic,
Tracer,
)
from firecrown.metadata.two_point import InferredGalaxyZDist
from firecrown.metadata_types import InferredGalaxyZDist
from firecrown.modeling_tools import ModelingTools
from firecrown.parameters import DerivedParameter, DerivedParameterCollection, ParamsMap
from firecrown.updatable import UpdatableCollection
Expand Down
Loading

0 comments on commit 4eef092

Please sign in to comment.