From 903630d820d17389d6bfdb308860c7f9fbaa1a0c Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Mon, 15 Jul 2024 08:10:37 -0500 Subject: [PATCH 01/10] Even more docstrings updates (#434) * Improve docstrings * Fix errors in example code --- examples/des_y1_3x2pt/des_y1_3x2pt_PT.py | 6 ++- firecrown/updatable.py | 47 ++++++++++++++++++------ 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/examples/des_y1_3x2pt/des_y1_3x2pt_PT.py b/examples/des_y1_3x2pt/des_y1_3x2pt_PT.py index 67f1eb9e..8e507639 100755 --- a/examples/des_y1_3x2pt/des_y1_3x2pt_PT.py +++ b/examples/des_y1_3x2pt/des_y1_3x2pt_PT.py @@ -223,6 +223,7 @@ def run_likelihood() -> None: # Apply the systematics parameters likelihood.update(systematics_params) + tools.update(systematics_params) # Prepare the cosmology object tools.prepare(ccl_cosmo) @@ -296,7 +297,7 @@ def plot_predicted_and_measured_statistics( # pylint: disable=import-outside-toplevel import matplotlib.pyplot as plt - ells = stat0.ells + ells = stat0.ells_for_xi cells = CElls(stat0, stat2, stat3) # Code that computes effect from IA using that Pk2D object @@ -339,6 +340,9 @@ def plot_predicted_and_measured_statistics( fig, ax = plt.subplots(2, 1, sharex=True, figsize=(6, 6)) fig.subplots_adjust(hspace=0) # ax[0].plot(x, y_theory, label="Total") + assert isinstance(ax, np.ndarray) + assert isinstance(ax[0], plt.Axes) + assert isinstance(ax[1], plt.Axes) ax[0].plot(ells, cells.GG, label="GG firecrown") ax[0].plot(ells, cl_GG, ls="--", label="GG CCL") ax[0].plot(ells, -cells.GI, label="-GI firecrown") diff --git a/firecrown/updatable.py b/firecrown/updatable.py index bb3d73bf..ff11af30 100644 --- a/firecrown/updatable.py +++ b/firecrown/updatable.py @@ -38,8 +38,11 @@ class MissingSamplerParameterError(RuntimeError): update is missing a parameter that should have been provided by the sampler. """ - def __init__(self, parameter: str): - """Create the error, with a meaning error message.""" + def __init__(self, parameter: str) -> None: + """Create the error, with a meaningful error message. + + :param parameter: name of the missing parameter + """ self.parameter = parameter msg = ( f"The parameter `{parameter}` is required to update " @@ -63,12 +66,12 @@ class Updatable(ABC): def __init__(self, parameter_prefix: None | str = None) -> None: """Updatable initialization. - :param prefix: prefix for all parameters in this Updatable - Parameters created by firecrown.parameters.create will have a prefix that is given by the prefix argument to the Updatable constructor. This prefix is used to create the full name of the parameter. If `parameter_prefix` is None, then the parameter will have no prefix. + + :param parameter_prefix: prefix for all parameters in this Updatable """ self._updated: bool = False self._returned_derived: bool = False @@ -86,6 +89,9 @@ def __setattr__(self, key: str, value: Any) -> None: We also keep track of all :class:`Updatable` instance variables added, appending a reference to each to :attr:`self._updatables` as well as storing the attribute directly. + + :param key: name of the attribute + :param value: value for the attribute """ if isinstance(value, (Updatable, UpdatableCollection)): self._updatables.append(value) @@ -101,6 +107,9 @@ def set_parameter( Assure this InternalParameter or SamplerParameter has not already been set, and then set it. 
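As an aside on the example-script hunk above: the script previously updated only the likelihood, leaving the ModelingTools with stale parameters at prepare time. A minimal sketch of the corrected sequence, assuming `likelihood`, `tools`, and `ccl_cosmo` are built earlier in run_likelihood(), with purely illustrative parameter values:

    from firecrown.parameters import ParamsMap

    systematics_params = ParamsMap({"ia_bias": 0.5, "alphaz": 0.0})  # illustrative values

    likelihood.update(systematics_params)
    tools.update(systematics_params)  # the call the patch adds
    tools.prepare(ccl_cosmo)

Both Updatable graphs must see the same ParamsMap before preparation; updating only one of them is the bug the first hunk fixes.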
+ + :param key: name of the attribute + :param value: value for the attribute """ if isinstance(value, SamplerParameter): value.set_fullname(self.parameter_prefix, key) @@ -109,7 +118,11 @@ def set_parameter( self.set_internal_parameter(key, value) def set_internal_parameter(self, key: str, value: InternalParameter) -> None: - """Assure this InternalParameter has not already been set, and then set it.""" + """Assure this InternalParameter has not already been set, and then set it. + + :param key: name of the attribute + :param value: value for the attribute + """ if not isinstance(value, InternalParameter): raise TypeError( "Can only add InternalParameter objects to internal_parameters" @@ -124,7 +137,10 @@ def set_internal_parameter(self, key: str, value: InternalParameter) -> None: super().__setattr__(key, value.get_value()) def set_sampler_parameter(self, value: SamplerParameter) -> None: - """Assure this SamplerParameter has not already been set, and then set it.""" + """Assure this SamplerParameter has not already been set, and then set it. + + :param value: value for the attribute + """ if not isinstance(value, SamplerParameter): raise TypeError( "Can only add SamplerParameter objects to sampler_parameters" @@ -182,10 +198,11 @@ def update(self, params: ParamsMap) -> None: def is_updated(self) -> bool: """Determine if the object has been updated. - Return True if the object is currently updated, and False if not. A default-constructed Updatable has not been updated. After `update`, but before `reset`, has been called the object is updated. After `reset` has been called, the object is not currently updated. + + :return: True if the object is currently updated, and False if not. """ return self._updated @@ -240,10 +257,12 @@ def _reset(self) -> None: # pragma: no cover @final def required_parameters(self) -> RequiredParameters: # pragma: no cover - """Returns a RequiredParameters object. + """Returns all information about parameters required by this object. + + This object returned contains the information for all parameters + defined in the implementing class, and any additional parameters. - This object contains the information for all parameters defined in the - implementing class, any additional parameter. + :return: a RequiredParameters object containing all relevant parameters """ sampler_parameters = RequiredParameters(self._sampler_parameters) additional_parameters = self._required_parameters() @@ -262,6 +281,8 @@ def _required_parameters(self) -> RequiredParameters: # pragma: no cover The base class implementation returns a list with all SamplerParameter objects properties. + + :return: a RequiredParameters containing all relevant parameters """ return RequiredParameters([]) @@ -273,6 +294,8 @@ def get_derived_parameters( This occurs once per iteration of the statistical analysis. First call returns the DerivedParameterCollection, further calls return None. + + :return: a collection of derived parameters, or None """ if not self._updated: raise RuntimeError( @@ -296,6 +319,8 @@ def _get_derived_parameters(self) -> DerivedParameterCollection: Derived classes can override this, returning a DerivedParameterCollection containing the derived parameters for the class. The default implementation returns an empty DerivedParameterCollection. 
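The lifecycle these docstrings describe (create, update, use, reset) can be sketched with a toy subclass; the class and parameter names below are invented, and firecrown.parameters.create is the helper the Updatable docstring above refers to:

    import firecrown.parameters
    from firecrown.parameters import ParamsMap
    from firecrown.updatable import Updatable

    class Toy(Updatable):
        """Toy Updatable with a single sampler parameter (illustrative only)."""

        def __init__(self) -> None:
            super().__init__(parameter_prefix="toy")
            # __setattr__ intercepts the SamplerParameter and registers it
            # under the full name "toy_bias".
            self.bias = firecrown.parameters.create()

    toy = Toy()
    assert not toy.is_updated()
    toy.update(ParamsMap({"toy_bias": 1.2}))  # omitting the key raises MissingSamplerParameterError
    assert toy.is_updated() and toy.bias == 1.2
    toy.reset()
    assert not toy.is_updated()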
+ + :return: a collection of derived parameters """ return DerivedParameterCollection([]) @@ -335,7 +360,7 @@ def __init__(self, iterable: None | Iterable[T] = None) -> None: def update(self, params: ParamsMap) -> None: """Update self by calling update() on each contained item. - :param params: new parameter values + :param params: new parameter values for each contained item """ if self._updated: return From b5544e9e910cf1f6abc39faf99d35eb761e2170c Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Mon, 15 Jul 2024 09:06:44 -0500 Subject: [PATCH 02/10] Update numcosmo connector docstrings (#433) * Update numcosmo.py * Restore #pylint: disable directives --- firecrown/connector/numcosmo/numcosmo.py | 285 +++++++++++++++++------ 1 file changed, 218 insertions(+), 67 deletions(-) diff --git a/firecrown/connector/numcosmo/numcosmo.py b/firecrown/connector/numcosmo/numcosmo.py index f8455757..0f88d04c 100644 --- a/firecrown/connector/numcosmo/numcosmo.py +++ b/firecrown/connector/numcosmo/numcosmo.py @@ -24,6 +24,9 @@ def get_hiprim(hi_cosmo: Nc.HICosmo) -> Nc.HIPrimPowerLaw: If hi_cosmo does not have a HIPrim object, a ValueError is raised. If the HIPrim object is not of type HIPrimPowerLaw, a ValueError is raised. + + :param hi_cosmo: NumCosmo HICosmo object + :return: the HIPrim object contained in hi_cosmo """ hiprim = hi_cosmo.peek_submodel_by_mid(Nc.HIPrim.id()) if not hiprim: @@ -49,8 +52,14 @@ def __init__( p_ml: None | Nc.PowspecML = None, p_mnl: None | Nc.PowspecMNL = None, dist: None | Nc.Distance = None, - ): - """Initialize a MappingNumCosmo object.""" + ) -> None: + """Initialize a MappingNumCosmo object. + + :param require_nonlinear_pk: whether to require a nonlinear power spectrum + :param p_ml: optional PowspecML object + :param p_mnl: optional PowspecMNL object + :param dist: optional Distance object + """ super().__init__( # type: ignore require_nonlinear_pk=require_nonlinear_pk, p_ml=p_ml, p_mnl=p_mnl, dist=dist ) @@ -67,11 +76,20 @@ def __init__( self._p_mnl = None def _get_mapping_name(self) -> str: - """Return the mapping name.""" + """Return the mapping name. + + :return: the name of the mapping. + """ return self._mapping_name - def _set_mapping_name(self, value: str): - """Set the mapping name.""" + def _set_mapping_name(self, value: str) -> None: + """Set the mapping name. + + This method also sets the :attr:`mapping` property to a default- + initialized :class:`Mapping` object. + + :param value: the new name of the mapping + """ self._mapping_name = value self.mapping = Mapping() @@ -84,11 +102,17 @@ def _set_mapping_name(self, value: str): ) def _get_require_nonlinear_pk(self) -> bool: - """Return whether nonlinear power spectra are required.""" + """Return whether nonlinear power spectra are required. + + :return: whether nonlinear power spectra are required + """ return self.mapping.require_nonlinear_pk - def _set_require_nonlinear_pk(self, value: bool): - """Set whether nonlinear power spectra are required.""" + def _set_require_nonlinear_pk(self, value: bool) -> None: + """Set whether nonlinear power spectra are required. + + :param value: whether nonlinear power spectra are required + """ self.mapping.require_nonlinear_pk = value require_nonlinear_pk = GObject.Property( @@ -100,11 +124,17 @@ def _set_require_nonlinear_pk(self, value: bool): ) def _get_p_ml(self) -> None | Nc.PowspecML: - """Return the NumCosmo PowspecML object.""" + """Return the NumCosmo PowspecML object. 
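All the property definitions in this connector follow the same PyGObject idiom shown above; a stripped-down sketch of the pattern, with invented names, for readers unfamiliar with GObject.Property:

    from gi.repository import GObject

    class Example(GObject.Object):
        """Minimal GObject property with explicit getter/setter, as in MappingNumCosmo."""

        def __init__(self) -> None:
            super().__init__()
            self._label = ""

        def _get_label(self) -> str:
            return self._label

        def _set_label(self, value: str) -> None:
            self._label = value

        # Registering the property makes it visible to GObject construction
        # (GObject.new) and to NumCosmo's serialization machinery.
        label = GObject.Property(type=str, default="", getter=_get_label, setter=_set_label)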
+ + :param value: the NumCosmo PowspecML object, or None + """ return self._p_ml - def _set_p_ml(self, value: None | Nc.PowspecML): - """Set the NumCosmo PowspecML object.""" + def _set_p_ml(self, value: None | Nc.PowspecML) -> None: + """Set the NumCosmo PowspecML object. + + :param value: the new value to be set + """ self._p_ml = value p_ml = GObject.Property( @@ -115,11 +145,17 @@ def _set_p_ml(self, value: None | Nc.PowspecML): ) def _get_p_mnl(self) -> None | Nc.PowspecMNL: - """Return the NumCosmo PowspecMNL object.""" + """Return the NumCosmo PowspecMNL object. + + :return: the NumCosmo PowspecMNL object, or None + """ return self._p_mnl - def _set_p_mnl(self, value: None | Nc.PowspecMNL): - """Set the NumCosmo PowspecMNL object.""" + def _set_p_mnl(self, value: None | Nc.PowspecMNL) -> None: + """Set the NumCosmo PowspecMNL object. + + :param value: the new value to be set + """ self._p_mnl = value p_mnl = GObject.Property( @@ -130,10 +166,13 @@ def _set_p_mnl(self, value: None | Nc.PowspecMNL): ) def _get_dist(self) -> None | Nc.Distance: - """Return the NumCosmo Distance object.""" + """Return the NumCosmo Distance object. + + :return: the NumCosmo Distance object, or None + """ return self._dist - def _set_dist(self, value: Nc.Distance): + def _set_dist(self, value: Nc.Distance) -> None: """Set the NumCosmo Distance object.""" self._dist = value @@ -146,10 +185,10 @@ def _set_dist(self, value: Nc.Distance): def set_params_from_numcosmo( self, mset: Ncm.MSet - ): # pylint: disable-msg=too-many-locals - """Return a PyCCLCosmologyConstants object. + ) -> None: # pylint: disable-msg=too-many-locals + """Set the parameters of the contained Mapping object. - This object will have parameters equivalent to those read from NumCosmo. + :param mset: the NumCosmo MSet object from which to get the parameters """ hi_cosmo = mset.peek(Nc.HICosmo.id()) assert isinstance(hi_cosmo, Nc.HICosmo) @@ -204,8 +243,14 @@ def set_params_from_numcosmo( ) # pylint: enable=duplicate-code - def calculate_ccl_args(self, mset: Ncm.MSet): # pylint: disable-msg=too-many-locals - """Calculate the arguments necessary for CCL for this sample.""" + def calculate_ccl_args( # pylint: disable-msg=too-many-locals + self, mset: Ncm.MSet + ) -> dict[str, Any]: + """Calculate the arguments necessary for CCL for this sample. + + :param mset: the NumCosmo MSet object from which to get the parameters + :return: a dictionary of the arguments required by CCL + """ ccl_args: dict[str, Any] = {} hi_cosmo = mset.peek(Nc.HICosmo.id()) assert isinstance(hi_cosmo, Nc.HICosmo) @@ -276,7 +321,12 @@ def calculate_ccl_args(self, mset: Ncm.MSet): # pylint: disable-msg=too-many-lo return ccl_args def create_params_map(self, model_list: list[str], mset: Ncm.MSet) -> ParamsMap: - """Create a ParamsMap from a NumCosmo MSet.""" + """Create a ParamsMap from a NumCosmo MSet. + + :param model_list: list of model names + :param mset: the NumCosmo MSet object from which to get the parameters + :return: a ParamsMap containing the parameters of the models in model_list + """ params_map = ParamsMap() for model_ns in model_list: mid = mset.get_id_by_ns(model_ns) @@ -313,6 +363,10 @@ class NumCosmoData(Ncm.Data): __gtype_name__ = "FirecrownNumCosmoData" def __init__(self): + """Initialize a NumCosmoData object. + + Default values are provided for all attributes; most default are None. 
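A sketch of constructing the mapping declared above: only require_nonlinear_pk is required, and the optional Distance with z_max = 6.0 is an illustrative choice, not a prescribed one:

    from numcosmo_py import Nc, Ncm
    from firecrown.connector.numcosmo.numcosmo import MappingNumCosmo

    Ncm.cfg_init()  # standard NumCosmo library initialization

    mapping = MappingNumCosmo(
        require_nonlinear_pk=True,  # ask CCL to compute a nonlinear power spectrum
        dist=Nc.Distance.new(6.0),  # optional; distances computed up to z = 6
    )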
+ """ super().__init__() self.likelihood: Likelihood self.tools: ModelingTools @@ -327,11 +381,17 @@ def __init__(self): self.set_init(True) def _get_model_list(self) -> list[str]: - """Return the list of models.""" + """Return the list of model names. + + :return: the names of the contained models + """ return self._model_list - def _set_model_list(self, value: list[str]): - """Set the list of models.""" + def _set_model_list(self, value: list[str]) -> None: + """Set the list of model names. + + :param value: the new of model names + """ self._model_list = value model_list = GObject.Property( @@ -342,11 +402,17 @@ def _set_model_list(self, value: list[str]): ) def _get_nc_mapping(self) -> MappingNumCosmo: - """Return the MappingNumCosmo object.""" + """Return the MappingNumCosmo object. + + :return: the MappingNumCosmo object + """ return self._nc_mapping - def _set_nc_mapping(self, value: MappingNumCosmo): - """Set the MappingNumCosmo object.""" + def _set_nc_mapping(self, value: MappingNumCosmo) -> None: + """Set the MappingNumCosmo object. + + :param: the new value for the MappingNumCosmo object + """ self._nc_mapping = value nc_mapping = GObject.Property( @@ -356,7 +422,7 @@ def _set_nc_mapping(self, value: MappingNumCosmo): setter=_set_nc_mapping, ) - def _set_likelihood_from_factory(self): + def _set_likelihood_from_factory(self) -> None: """Deserialize the likelihood.""" assert self._likelihood_source is not None assert self._likelihood_build_parameters is not None @@ -369,11 +435,17 @@ def _set_likelihood_from_factory(self): self.tools = tools def _get_likelihood_source(self) -> None | str: - """Return the likelihood string defining the factory function.""" + """Return the likelihood string defining the factory function. + + :return: the filename of the likelihood factory function + """ return self._likelihood_source def _set_likelihood_source(self, value: None | str): - """Set the likelihood string defining the factory function.""" + """Set the likelihood string defining the factory function. + + :param value: the filename of the likelihood factory function + """ if value is not None: self._likelihood_source = value if self._starting_deserialization: @@ -397,8 +469,11 @@ def _get_likelihood_build_parameters(self) -> None | Ncm.VarDict: self._likelihood_build_parameters.convert_to_basic_dict() ) - def _set_likelihood_build_parameters(self, value: None | Ncm.VarDict): - """Set the likelihood build parameters.""" + def _set_likelihood_build_parameters(self, value: None | Ncm.VarDict) -> None: + """Set the likelihood build parameters. + + :param value: the parameters used to build the likelihood + """ self._likelihood_build_parameters = NamedParameters() if value is not None: self._likelihood_build_parameters.set_from_basic_dict( @@ -427,10 +502,17 @@ def new_from_likelihood( nc_mapping: MappingNumCosmo, likelihood_source: None | str = None, likelihood_build_parameters: None | NamedParameters = None, - ): + ) -> "NumCosmoData": """Initialize a NumCosmoGaussCov object. This object represents a Gaussian likelihood with a constant covariance. 
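Continuing that sketch, the classmethod above is the low-level way to wrap an already-loaded Firecrown likelihood for NumCosmo; the factory file name, build parameters, and model name below are placeholders:

    from firecrown.likelihood.likelihood import NamedParameters, load_likelihood
    from firecrown.connector.numcosmo.numcosmo import NumCosmoData

    build_parameters = NamedParameters({"sacc_file": "data.sacc"})  # hypothetical
    likelihood, tools = load_likelihood("my_likelihood.py", build_parameters)

    nc_data = NumCosmoData.new_from_likelihood(
        likelihood,
        model_list=["NcFirecrownModel"],  # hypothetical NumCosmo model name
        tools=tools,
        nc_mapping=mapping,  # the MappingNumCosmo from the earlier sketch
        likelihood_source="my_likelihood.py",
        likelihood_build_parameters=build_parameters,
    )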
+ + :param likelihood: the likelihood object + :param model_list: the list of model names + :param tools: the modeling tools + :param nc_mapping: the mapping object + :param likelihood_source: the filename for the likelihood factory function + :param likelihood_build_parameters: the build parameters of the likelihood """ nc_data: NumCosmoData = GObject.new( cls, @@ -447,15 +529,21 @@ def new_from_likelihood( return nc_data - def do_get_length(self): # pylint: disable-msg=arguments-differ - """Implements the virtual Ncm.Data method get_length.""" + def do_get_length(self) -> int: # pylint: disable-msg=arguments-differ + """Implements the virtual Ncm.Data method get_length. + + :return: the number of data points in the likelihood + """ return self.len - def do_get_dof(self): # pylint: disable-msg=arguments-differ - """Implements the virtual Ncm.Data method get_dof.""" + def do_get_dof(self) -> int: # pylint: disable-msg=arguments-differ + """Implements the virtual Ncm.Data method get_dof. + + :return: the number of degrees of freedom in the likelihood + """ return self.dof - def do_begin(self): # pylint: disable-msg=arguments-differ + def do_begin(self) -> None: # pylint: disable-msg=arguments-differ """Implements the virtual Ncm.Data method `begin`. This method usually do some groundwork in the data @@ -464,11 +552,15 @@ def do_begin(self): # pylint: disable-msg=arguments-differ during `begin` once and then used afterwards. """ - def do_prepare(self, mset: Ncm.MSet): # pylint: disable-msg=arguments-differ + def do_prepare( # pylint: disable-msg=arguments-differ + self, mset: Ncm.MSet + ) -> None: """Implements the virtual method Ncm.Data `prepare`. This method should do all the necessary calculations using mset to be able to calculate the likelihood afterwards. + + :param mset: the model set """ self.dof = self.len - mset.fparams_len() self.likelihood.reset() @@ -485,11 +577,13 @@ def do_prepare(self, mset: Ncm.MSet): # pylint: disable-msg=arguments-differ self.tools.update(params_map) self.tools.prepare(self.ccl_cosmo) - def do_m2lnL_val(self, _): # pylint: disable-msg=arguments-differ + def do_m2lnL_val(self, _) -> float: # pylint: disable-msg=arguments-differ """Implements the virtual method `m2lnL`. This method should calculate the value of the likelihood for the model set `mset`. + + :param _: unused, but required by interface """ loglike = self.likelihood.compute_loglike(self.tools) return -2.0 * loglike @@ -504,7 +598,7 @@ class NumCosmoGaussCov(Ncm.DataGaussCov): __gtype_name__ = "FirecrownNumCosmoGaussCov" - def __init__(self): + def __init__(self) -> None: """Initialize a NumCosmoGaussCov object. This class is a subclass of Ncm.DataGaussCov and implements NumCosmo @@ -515,7 +609,7 @@ def __init__(self): argument, and the properties must be set after construction. In python one should use the `new_from_likelihood` method to create a - NumCosmoGaussCov object from a ConstGaussian object. This constuctor + NumCosmoGaussCov object from a ConstGaussian object. This constructor has the correct signature for type checking. """ super().__init__() @@ -531,11 +625,17 @@ def __init__(self): self._starting_deserialization: bool = False def _get_model_list(self) -> list[str]: - """Return the list of models.""" + """Return the list of models. + + :return: the current names of the models + """ return self._model_list - def _set_model_list(self, value: list[str]): - """Set the list of models.""" + def _set_model_list(self, value: list[str]) -> None: + """Set the list of models. 
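For orientation, the quantity do_m2lnL_val returns is -2 ln L; for a Gaussian likelihood with constant covariance this reduces, up to an additive constant, to the chi-squared form. A self-contained sketch:

    import numpy as np

    def m2lnL(residuals: np.ndarray, cov: np.ndarray) -> float:
        """-2 ln L of a constant-covariance Gaussian, up to an additive
        constant: r^T C^{-1} r, with r = data - theory."""
        return float(residuals @ np.linalg.solve(cov, residuals))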
+ + :param value: the new list of model names + """ self._model_list = value model_list = GObject.Property( @@ -546,11 +646,17 @@ def _set_model_list(self, value: list[str]): ) def _get_nc_mapping(self) -> MappingNumCosmo: - """Return the :class:`MappingNumCosmo` object.""" + """Return the :class:`MappingNumCosmo` object. + + :return: the current value of the mapping + """ return self._nc_mapping def _set_nc_mapping(self, value: MappingNumCosmo): - """Set the MappingNumCosmo object.""" + """Set the MappingNumCosmo object. + + :param: the new value for the MappingNumCosmo object + """ self._nc_mapping = value nc_mapping = GObject.Property( @@ -560,7 +666,7 @@ def _set_nc_mapping(self, value: MappingNumCosmo): setter=_set_nc_mapping, ) - def _configure_object(self): + def _configure_object(self) -> None: """Configure the object.""" assert self.likelihood is not None @@ -584,7 +690,7 @@ def _configure_object(self): self.set_init(True) - def _set_likelihood_from_factory(self): + def _set_likelihood_from_factory(self) -> None: """Deserialize the likelihood.""" assert self._likelihood_source is not None assert self._likelihood_build_parameters is not None @@ -598,11 +704,17 @@ def _set_likelihood_from_factory(self): self._configure_object() def _get_likelihood_source(self) -> None | str: - """Return the likelihood string defining the factory function.""" + """Return the likelihood string defining the factory function. + + :return: the filename of the likelihood factory function + """ return self._likelihood_source - def _set_likelihood_source(self, value: None | str): - """Set the likelihood string defining the factory function.""" + def _set_likelihood_source(self, value: None | str) -> None: + """Set the likelihood string defining the factory function. + + :param value: the filename of the likelihood factory function + """ if value is not None: self._likelihood_source = value if self._starting_deserialization: @@ -619,15 +731,21 @@ def _set_likelihood_source(self, value: None | str): ) def _get_likelihood_build_parameters(self) -> None | Ncm.VarDict: - """Return the likelihood build parameters.""" + """Return the likelihood build parameters. + + :return: the likelihood build parameters + """ if self._likelihood_build_parameters is None: return None return dict_to_var_dict( self._likelihood_build_parameters.convert_to_basic_dict() ) - def _set_likelihood_build_parameters(self, value: None | Ncm.VarDict): - """Set the likelihood build parameters.""" + def _set_likelihood_build_parameters(self, value: None | Ncm.VarDict) -> None: + """Set the likelihood build parameters. + + :param value: the likelihood build parameters + """ self._likelihood_build_parameters = NamedParameters() if value is not None: self._likelihood_build_parameters.set_from_basic_dict( @@ -659,6 +777,11 @@ def new_from_likelihood( """Initialize a NumCosmoGaussCov object. This object represents a Gaussian likelihood with a constant covariance. 
+ :param likelihood: the likelihood object + :param model_list: the list of model names + :param nc_mapping: the mapping object + :param likelihood_source: the filename for the likelihood factory function + :param likelihood_build_parameters: the build parameters of the likelihood """ cov = likelihood.get_cov() nrows, ncols = cov.shape @@ -684,15 +807,21 @@ def new_from_likelihood( return nc_gauss_cov - def do_get_length(self): # pylint: disable-msg=arguments-differ - """Implements the virtual `Ncm.Data` method `get_length`.""" + def do_get_length(self) -> int: # pylint: disable-msg=arguments-differ + """Implements the virtual `Ncm.Data` method `get_length`. + + :return: the number of data points in the likelihood + """ return self.len - def do_get_dof(self): # pylint: disable-msg=arguments-differ - """Implements the virtual `Ncm.Data` method `get_dof`.""" + def do_get_dof(self) -> int: # pylint: disable-msg=arguments-differ + """Implements the virtual `Ncm.Data` method `get_dof`. + + :return: the number of degrees of freedom in the likelihood + """ return self.dof - def do_begin(self): # pylint: disable-msg=arguments-differ + def do_begin(self) -> None: # pylint: disable-msg=arguments-differ """Implements the virtual `Ncm.Data` method `begin`. This method usually do some groundwork in the data @@ -701,11 +830,14 @@ def do_begin(self): # pylint: disable-msg=arguments-differ during `begin` once and then used afterwards. """ - def do_prepare(self, mset: Ncm.MSet): # pylint: disable-msg=arguments-differ + def do_prepare( # pylint: disable-msg=arguments-differ + self, mset: Ncm.MSet + ) -> None: """Implements the virtual method Ncm.Data `prepare`. This method should do all the necessary calculations using mset to be able to calculate the likelihood afterwards. + :param mset: the model set """ self.dof = self.len - mset.fparams_len() self.likelihood.reset() @@ -724,11 +856,14 @@ def do_prepare(self, mset: Ncm.MSet): # pylint: disable-msg=arguments-differ self.tools.prepare(self.ccl_cosmo) # pylint: disable-next=arguments-differ - def do_mean_func(self, _, vp): + def do_mean_func(self, _, vp) -> None: """Implements the virtual `Ncm.DataGaussCov` method `mean_func`. This method should compute the theoretical mean for the gaussian distribution. + + :param _: unused, but required by interface + :param vp: the vector to set """ theory_vector = self.likelihood.compute_theory_vector(self.tools) vp.set_array(theory_vector) @@ -756,7 +891,14 @@ def __init__( build_parameters: NamedParameters, mapping: MappingNumCosmo, model_list: list[str], - ): + ) -> None: + """Initialize a NumCosmoFactory. + + :param likelihood_source: the filename for the likelihood factory function + :param build_parameters: the build parameters + :param mapping: the mapping + :param model_list: the model list + """ likelihood, tools = load_likelihood(likelihood_source, build_parameters) self.data: NumCosmoGaussCov | NumCosmoData @@ -781,13 +923,22 @@ def __init__( ) def get_data(self) -> Ncm.Data: - """This method return the appropriated Ncm.Data class to be used by NumCosmo.""" + """This method return the appropriate Ncm.Data class to be used by NumCosmo. + + :return: the data used by NumCosmo + """ return self.data def get_mapping(self) -> MappingNumCosmo: - """This method return the current MappingNumCosmo.""" + """This method return the current MappingNumCosmo. + + :return: the current mapping. 
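The NumCosmoFactory above bundles those steps into one call; a sketch of typical use, again with placeholder file and model names:

    from firecrown.likelihood.likelihood import NamedParameters
    from firecrown.connector.numcosmo.numcosmo import MappingNumCosmo, NumCosmoFactory

    factory = NumCosmoFactory(
        "my_likelihood.py",  # hypothetical likelihood factory file
        NamedParameters({"sacc_file": "data.sacc"}),
        MappingNumCosmo(require_nonlinear_pk=True),
        model_list=["NcFirecrownModel"],  # hypothetical
    )
    data = factory.get_data()  # NumCosmoGaussCov for a ConstGaussian, NumCosmoData otherwise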
+        """
        return self.mapping

    def get_firecrown_likelihood(self) -> Likelihood:
-        """This method returns the firecrown Likelihood."""
+        """This method returns the Firecrown Likelihood.
+
+        :return: the likelihood
+        """
        return self.data.likelihood


From a6c217a84b43e49a7341566f98e0507d7de159d6 Mon Sep 17 00:00:00 2001
From: Sandro Dias Pinto Vitenti 
Date: Tue, 16 Jul 2024 15:57:15 -0300
Subject: [PATCH 03/10] Two point data (#432)

* Reorganizing two-point types into a separate module.
* Adding optional data to two_point dataclasses.
* Updated usage of two_point dataclasses.
* Removed required data_vector when computing theory_vector.
* Testing new constructors.
* Comparing new and old constructors.
* Testing new measurement code.
* New constructor for gauss family for creating ready likelihoods.
* New tutorial for two-point factories.
* Constructors to create ready objects.
* Added SHEAR_MINUS and SHEAR_PLUS.
* Created factories to create both ready and not-ready objects.
* Updated tests to new types.
* Adding likelihood creation to tutorial.
* Adding support for multiple measurements in the same bin.
* Added likelihood comparison on two_point_factories.
* Splitting metadata test file.
* Updating test to deal with enum str representations in Python 3.10.
* Reorganizing fixtures in conftest.py.
* Testing ell and theta wrong-shape check.
* Testing multiple data_types.
* Testing inverting tracer order.
* Testing name-type matching.
* Testing naming-convention enforcement.
* Testing the consistence checks for a set of two-point dataclasses.
* Testing indices and covariance consistency.
* Testing sacc covariance matrix presence.
* Testing ambiguous types.
* Reorganized imports and added missing tests to TwoPointCWindow.
* Updated tutorials to match new module names.
* Testing unsupported-type exception.
* Testing metadata-only constructors.
* Testing ConstGaussian create_ready.
* Fixing covariance-size test.
* Testing the use of non-ready statistics.
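Before the diff itself, a sketch of the new factory entry point this patch advertises; the pydantic factory fields shown are assumptions based on the code below, and the empty systematics lists are just the simplest configuration:

    from firecrown.likelihood.number_counts import NumberCountsFactory
    from firecrown.likelihood.two_point import TwoPoint
    from firecrown.likelihood.weak_lensing import WeakLensingFactory

    wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[])
    nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[])

    # `all_cells` stands for a list of TwoPointCells extracted from a SACC file;
    # when the metadata carries data and indices, the statistics come back ready.
    statistics = TwoPoint.from_metadata_harmonic(
        all_cells,
        wl_factory=wl_factory,
        nc_factory=nc_factory,
        check_consistence=True,
    )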
--- firecrown/generators/inferred_galaxy_zdist.py | 78 +- firecrown/likelihood/gaussfamily.py | 39 +- firecrown/likelihood/number_counts.py | 45 +- firecrown/likelihood/two_point.py | 338 ++++-- firecrown/likelihood/weak_lensing.py | 37 +- firecrown/metadata/two_point.py | 1030 ++++++----------- firecrown/metadata/two_point_types.py | 752 ++++++++++++ firecrown/utils.py | 34 +- tests/conftest.py | 420 ++++++- .../generators/test_inferred_galaxy_zdist.py | 86 +- .../statistic/source/test_source.py | 30 +- .../gauss_family/statistic/test_two_point.py | 128 ++ .../gauss_family/test_const_gaussian.py | 68 ++ tests/metadata/sacc_name_mapping.py | 2 +- tests/metadata/test_metadata_two_point.py | 577 +++++---- .../test_metadata_two_point_measurement.py | 217 ++++ .../metadata/test_metadata_two_point_sacc.py | 415 ++++++- .../metadata/test_metadata_two_point_types.py | 581 ++++++++++ tests/test_utils.py | 19 + tutorial/_quarto.yml | 3 + tutorial/inferred_zdist.qmd | 15 +- tutorial/inferred_zdist_generators.qmd | 16 +- tutorial/inferred_zdist_serialization.qmd | 6 +- tutorial/two_point_factories.qmd | 206 ++++ tutorial/two_point_generators.qmd | 2 +- 25 files changed, 3887 insertions(+), 1257 deletions(-) create mode 100644 firecrown/metadata/two_point_types.py create mode 100644 tests/metadata/test_metadata_two_point_measurement.py create mode 100644 tests/metadata/test_metadata_two_point_types.py create mode 100644 tutorial/two_point_factories.qmd diff --git a/firecrown/generators/inferred_galaxy_zdist.py b/firecrown/generators/inferred_galaxy_zdist.py index 8ac9fa50..430a2fc3 100644 --- a/firecrown/generators/inferred_galaxy_zdist.py +++ b/firecrown/generators/inferred_galaxy_zdist.py @@ -12,16 +12,19 @@ from numcosmo_py import Ncm -from firecrown.metadata.two_point import ( +from firecrown.metadata.two_point_types import ( InferredGalaxyZDist, + ALL_MEASUREMENT_TYPES, +) +from firecrown.metadata.two_point_types import ( + make_measurements_dict, Measurement, Galaxies, CMB, Clusters, - ALL_MEASUREMENT_TYPES, - make_measurement_dict, ) + BinsType = TypedDict("BinsType", {"edges": npt.NDArray, "sigma_z": float}) Y1_ALPHA = 0.94 @@ -159,7 +162,7 @@ def binned_distribution( sigma_z: float, z: npt.NDArray, name: str, - measurement: Measurement, + measurements: set[Measurement], use_autoknot: bool = False, autoknots_reltol: float = 1.0e-4, autoknots_abstol: float = 1.0e-15, @@ -171,7 +174,7 @@ def binned_distribution( :param sigma_z: The resolution parameter :param z: The redshifts at which to evaluate the distribution :param name: The name of the distribution - :param measured_type: The measured type of the distribution + :param measurements: The set of measurements of the distribution :param use_autoknot: Whether to use the NotAKnot algorithm of NumCosmo :param autoknots_reltol: The relative tolerance for the NotAKnot algorithm :param autoknots_abstol: The absolute tolerance for the NotAKnot algorithm @@ -217,7 +220,7 @@ def _P(z, _): ) return InferredGalaxyZDist( - bin_name=name, z=z_knots, dndz=dndz, measurement=measurement + bin_name=name, z=z_knots, dndz=dndz, measurements=measurements ) @@ -250,31 +253,40 @@ def generate(self) -> npt.NDArray: Grid1D = LinearGrid1D | RawGrid1D -def make_measurement(value: Measurement | dict[str, Any]) -> Measurement: +def make_measurements( + value: set[Measurement] | list[dict[str, Any]] +) -> set[Measurement]: """Create a Measurement object from a dictionary.""" - if isinstance(value, ALL_MEASUREMENT_TYPES): + if isinstance(value, set) and all( + isinstance(v, 
ALL_MEASUREMENT_TYPES) for v in value + ): return value - if not isinstance(value, dict): - raise ValueError(f"Invalid Measurement: {value} is not a dictionary") + measurements: set[Measurement] = set() + for measurement_dict in value: + if not isinstance(measurement_dict, dict): + raise ValueError(f"Invalid Measurement: {value} is not a dictionary") - if "subject" not in value: - raise ValueError("Invalid Measurement: dictionary does not contain 'subject'") - - subject = value["subject"] - - match subject: - case "Galaxies": - return Galaxies[value["property"]] - case "CMB": - return CMB[value["property"]] - case "Clusters": - return Clusters[value["property"]] - case _: + if "subject" not in measurement_dict: raise ValueError( - f"Invalid Measurement: subject: '{subject}' is not recognized" + "Invalid Measurement: dictionary does not contain 'subject'" ) + subject = measurement_dict["subject"] + + match subject: + case "Galaxies": + measurements.update({Galaxies[measurement_dict["property"]]}) + case "CMB": + measurements.update({CMB[measurement_dict["property"]]}) + case "Clusters": + measurements.update({Clusters[measurement_dict["property"]]}) + case _: + raise ValueError( + f"Invalid Measurement: subject: '{subject}' is not recognized" + ) + return measurements + class ZDistLSSTSRDBin(BaseModel): """LSST Inferred galaxy redshift distributions in bins.""" @@ -286,16 +298,16 @@ class ZDistLSSTSRDBin(BaseModel): sigma_z: float z: Annotated[Grid1D, Field(union_mode="left_to_right")] bin_name: str - measurement: Annotated[Measurement, BeforeValidator(make_measurement)] + measurements: Annotated[set[Measurement], BeforeValidator(make_measurements)] use_autoknot: bool = False autoknots_reltol: float = 1.0e-4 autoknots_abstol: float = 1.0e-15 - @field_serializer("measurement") + @field_serializer("measurements") @classmethod - def serialize_measurement(cls, value: Measurement) -> dict: + def serialize_measurements(cls, value: set[Measurement]) -> list[dict]: """Serialize the Measurement.""" - return make_measurement_dict(value) + return make_measurements_dict(value) def generate(self, zdist: ZDistLSSTSRD) -> InferredGalaxyZDist: """Generate the inferred galaxy redshift distribution in bins.""" @@ -305,7 +317,7 @@ def generate(self, zdist: ZDistLSSTSRD) -> InferredGalaxyZDist: sigma_z=self.sigma_z, z=self.z.generate(), name=self.bin_name, - measurement=self.measurement, + measurements=self.measurements, use_autoknot=self.use_autoknot, autoknots_reltol=self.autoknots_reltol, autoknots_abstol=self.autoknots_abstol, @@ -339,7 +351,7 @@ def generate(self) -> list[InferredGalaxyZDist]: sigma_z=Y1_LENS_BINS["sigma_z"], z=RawGrid1D(values=[0.0, 3.0]), bin_name=f"lens_{zpl:.1f}_{zpu:.1f}_y1", - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, use_autoknot=True, autoknots_reltol=1.0e-5, ) @@ -358,7 +370,7 @@ def generate(self) -> list[InferredGalaxyZDist]: sigma_z=Y1_SOURCE_BINS["sigma_z"], z=RawGrid1D(values=[0.0, 3.0]), bin_name=f"source_{zpl:.1f}_{zpu:.1f}_y1", - measurement=Galaxies.SHEAR_E, + measurements={Galaxies.SHEAR_E}, use_autoknot=True, autoknots_reltol=1.0e-5, ) @@ -377,7 +389,7 @@ def generate(self) -> list[InferredGalaxyZDist]: sigma_z=Y10_LENS_BINS["sigma_z"], z=RawGrid1D(values=[0.0, 3.0]), bin_name=f"lens_{zpl:.1f}_{zpu:.1f}_y10", - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, use_autoknot=True, autoknots_reltol=1.0e-5, ) @@ -396,7 +408,7 @@ def generate(self) -> list[InferredGalaxyZDist]: sigma_z=Y10_SOURCE_BINS["sigma_z"], z=RawGrid1D(values=[0.0, 
3.0]), bin_name=f"source_{zpl:.1f}_{zpu:.1f}_y10", - measurement=Galaxies.SHEAR_E, + measurements={Galaxies.SHEAR_E}, use_autoknot=True, autoknots_reltol=1.0e-5, ) diff --git a/firecrown/likelihood/gaussfamily.py b/firecrown/likelihood/gaussfamily.py index a5a38936..80ff5bbc 100644 --- a/firecrown/likelihood/gaussfamily.py +++ b/firecrown/likelihood/gaussfamily.py @@ -156,6 +156,16 @@ def __init__( self.theory_vector: None | npt.NDArray[np.double] = None self.data_vector: None | npt.NDArray[np.double] = None + @classmethod + def create_ready( + cls, statistics: Sequence[Statistic], covariance: npt.NDArray[np.float64] + ) -> GaussFamily: + """Create a GaussFamily object in the READY state.""" + obj = cls(statistics) + obj._set_covariance(covariance) + obj.state = State.READY + return obj + @enforce_states( initial=State.READY, terminal=State.UPDATED, @@ -199,12 +209,26 @@ def read(self, sacc_data: sacc.Sacc) -> None: ) raise RuntimeError(msg) + for stat in self.statistics: + stat.read(sacc_data) + covariance = sacc_data.covariance.dense + self._set_covariance(covariance) + + def _set_covariance(self, covariance): + """Set the covariance matrix. + + This method is used to set the covariance matrix and perform the + necessary calculations to prepare the likelihood for computation. + """ indices_list = [] data_vector_list = [] for stat in self.statistics: - stat.read(sacc_data) + if not stat.statistic.ready: + raise RuntimeError( + f"The statistic {stat.statistic} is not ready to be used." + ) if stat.statistic.sacc_indices is None: raise RuntimeError( f"The statistic {stat.statistic} has no sacc_indices." @@ -216,6 +240,19 @@ def read(self, sacc_data: sacc.Sacc) -> None: data_vector = np.concatenate(data_vector_list) cov = np.zeros((len(indices), len(indices))) + largest_index = np.max(indices) + + if not ( + covariance.ndim == 2 + and covariance.shape[0] == covariance.shape[1] + and largest_index < covariance.shape[0] + ): + raise ValueError( + f"The covariance matrix has shape {covariance.shape}, " + f"but the expected shape is at least " + f"{(largest_index + 1, largest_index + 1)}." 
+ ) + for new_i, old_i in enumerate(indices): for new_j, old_j in enumerate(indices): cov[new_i, new_j] = covariance[old_i, old_j] diff --git a/firecrown/likelihood/number_counts.py b/firecrown/likelihood/number_counts.py index 69f1755c..48b3a3ae 100644 --- a/firecrown/likelihood/number_counts.py +++ b/firecrown/likelihood/number_counts.py @@ -487,9 +487,9 @@ class PhotoZShiftFactory(BaseModel): Field(description="The type of the systematic."), ] = "PhotoZShiftFactory" - def create(self, inferred_zdist: InferredGalaxyZDist) -> PhotoZShift: + def create(self, bin_name: str) -> PhotoZShift: """Create a PhotoZShift object with the given tracer name.""" - return PhotoZShift(inferred_zdist.bin_name) + return PhotoZShift(bin_name) def create_global(self) -> PhotoZShift: """Create a PhotoZShift object with the given tracer name.""" @@ -506,9 +506,9 @@ class LinearBiasSystematicFactory(BaseModel): Field(description="The type of the systematic."), ] = "LinearBiasSystematicFactory" - def create(self, inferred_zdist: InferredGalaxyZDist) -> LinearBiasSystematic: + def create(self, bin_name: str) -> LinearBiasSystematic: """Create a LinearBiasSystematic object with the given tracer name.""" - return LinearBiasSystematic(inferred_zdist.bin_name) + return LinearBiasSystematic(bin_name) def create_global(self) -> LinearBiasSystematic: """Create a LinearBiasSystematic object with the given tracer name.""" @@ -525,9 +525,9 @@ class PTNonLinearBiasSystematicFactory(BaseModel): Field(description="The type of the systematic."), ] = "PTNonLinearBiasSystematicFactory" - def create(self, inferred_zdist: InferredGalaxyZDist) -> PTNonLinearBiasSystematic: + def create(self, bin_name: str) -> PTNonLinearBiasSystematic: """Create a PTNonLinearBiasSystematic object with the given tracer name.""" - return PTNonLinearBiasSystematic(inferred_zdist.bin_name) + return PTNonLinearBiasSystematic(bin_name) def create_global(self) -> PTNonLinearBiasSystematic: """Create a PTNonLinearBiasSystematic object with the given tracer name.""" @@ -544,11 +544,9 @@ class MagnificationBiasSystematicFactory(BaseModel): Field(description="The type of the systematic."), ] = "MagnificationBiasSystematicFactory" - def create( - self, inferred_zdist: InferredGalaxyZDist - ) -> MagnificationBiasSystematic: + def create(self, bin_name: str) -> MagnificationBiasSystematic: """Create a MagnificationBiasSystematic object with the given tracer name.""" - return MagnificationBiasSystematic(inferred_zdist.bin_name) + return MagnificationBiasSystematic(bin_name) def create_global(self) -> MagnificationBiasSystematic: """Create a MagnificationBiasSystematic object with the given tracer name.""" @@ -565,14 +563,12 @@ class ConstantMagnificationBiasSystematicFactory(BaseModel): Field(description="The type of the systematic."), ] = "ConstantMagnificationBiasSystematicFactory" - def create( - self, inferred_zdist: InferredGalaxyZDist - ) -> ConstantMagnificationBiasSystematic: + def create(self, bin_name: str) -> ConstantMagnificationBiasSystematic: """Create a ConstantMagnificationBiasSystematic object. Use the inferred_zdist to create the systematic. """ - return ConstantMagnificationBiasSystematic(inferred_zdist.bin_name) + return ConstantMagnificationBiasSystematic(bin_name) def create_global(self) -> ConstantMagnificationBiasSystematic: """Create a ConstantMagnificationBiasSystematic object. 
@@ -618,7 +614,7 @@ def create(self, inferred_zdist: InferredGalaxyZDist) -> NumberCounts: return self._cache[inferred_zdist_id] systematics: list[SourceGalaxySystematic[NumberCountsArgs]] = [ - systematic_factory.create(inferred_zdist) + systematic_factory.create(inferred_zdist.bin_name) for systematic_factory in self.per_bin_systematics ] systematics.extend(self._global_systematics_instances) @@ -627,3 +623,22 @@ def create(self, inferred_zdist: InferredGalaxyZDist) -> NumberCounts: self._cache[inferred_zdist_id] = nc return nc + + def create_from_metadata_only( + self, + sacc_tracer: str, + ) -> NumberCounts: + """Create an WeakLensing object with the given tracer name and scale.""" + sacc_tracer_id = hash(sacc_tracer) # Improve this + if sacc_tracer_id in self._cache: + return self._cache[sacc_tracer_id] + systematics: list[SourceGalaxySystematic[NumberCountsArgs]] = [ + systematic_factory.create(sacc_tracer) + for systematic_factory in self.per_bin_systematics + ] + systematics.extend(self._global_systematics_instances) + + nc = NumberCounts(sacc_tracer=sacc_tracer, systematics=systematics) + self._cache[sacc_tracer_id] = nc + + return nc diff --git a/firecrown/likelihood/two_point.py b/firecrown/likelihood/two_point.py index 986b138b..0f1becbd 100644 --- a/firecrown/likelihood/two_point.py +++ b/firecrown/likelihood/two_point.py @@ -31,14 +31,25 @@ Statistic, TheoryVector, ) -from firecrown.metadata.two_point import ( - Galaxies, - InferredGalaxyZDist, +from firecrown.metadata.two_point_types import ( TRACER_NAMES_TOTAL, + InferredGalaxyZDist, + Galaxies, + Measurement, +) + +from firecrown.metadata.two_point import ( TracerNames, TwoPointCells, + TwoPointCWindow, + TwoPointXiTheta, + TwoPointCellsIndex, + TwoPointXiThetaIndex, Window, extract_window_function, + check_two_point_consistence_harmonic, + check_two_point_consistence_real, + measurements_from_index, ) from firecrown.modeling_tools import ModelingTools from firecrown.updatable import UpdatableCollection @@ -223,53 +234,58 @@ def apply_theta_min_max( def use_source_factory( inferred_galaxy_zdist: InferredGalaxyZDist, + measurement: Measurement, wl_factory: WeakLensingFactory | None = None, nc_factory: NumberCountsFactory | None = None, ) -> WeakLensing | NumberCounts: """Apply the factory to the inferred galaxy redshift distribution.""" source: WeakLensing | NumberCounts - match inferred_galaxy_zdist.measurement: + if measurement not in inferred_galaxy_zdist.measurements: + raise ValueError( + f"Measurement {measurement} not found in inferred galaxy redshift " + f"distribution {inferred_galaxy_zdist.bin_name}!" + ) + + match measurement: case Galaxies.COUNTS: assert nc_factory is not None source = nc_factory.create(inferred_galaxy_zdist) - case Galaxies.SHEAR_E | Galaxies.SHEAR_T: + case ( + Galaxies.SHEAR_E + | Galaxies.SHEAR_T + | Galaxies.SHEAR_MINUS + | Galaxies.SHEAR_PLUS + ): assert wl_factory is not None source = wl_factory.create(inferred_galaxy_zdist) case _: - raise ValueError( - f"Measurement {inferred_galaxy_zdist.measurement} not supported!" 
- ) + raise ValueError(f"Measurement {measurement} not supported!") return source -def create_sacc(metadata: list[TwoPointCells]): - """Fill the SACC file with the inferred galaxy redshift distributions.""" - sacc_data = sacc.Sacc() - - dv = [] - for cell in metadata: - for inferred_galaxy_zdist in (cell.XY.x, cell.XY.y): - if inferred_galaxy_zdist.bin_name not in sacc_data.tracers: - sacc_data.add_tracer( - "NZ", - inferred_galaxy_zdist.bin_name, - inferred_galaxy_zdist.z, - inferred_galaxy_zdist.dndz, - ) - cells = np.ones_like(cell.ells) - sacc_data.add_ell_cl( - cell.get_sacc_name(), - cell.XY.x.bin_name, - cell.XY.y.bin_name, - ell=cell.ells, - x=cells, - ) - dv.append(cells) - - delta_v = np.concatenate(dv, axis=0) - sacc_data.add_covariance(np.diag(np.ones_like(delta_v))) - - return sacc_data +def use_source_factory_metadata_only( + sacc_tracer: str, + measurement: Measurement, + wl_factory: WeakLensingFactory | None = None, + nc_factory: NumberCountsFactory | None = None, +) -> WeakLensing | NumberCounts: + """Apply the factory to create a source from metadata only.""" + source: WeakLensing | NumberCounts + match measurement: + case Galaxies.COUNTS: + assert nc_factory is not None + source = nc_factory.create_from_metadata_only(sacc_tracer) + case ( + Galaxies.SHEAR_E + | Galaxies.SHEAR_T + | Galaxies.SHEAR_MINUS + | Galaxies.SHEAR_PLUS + ): + assert wl_factory is not None + source = wl_factory.create_from_metadata_only(sacc_tracer) + case _: + raise ValueError(f"Measurement {measurement} not supported!") + return source class TwoPoint(Statistic): @@ -374,68 +390,249 @@ def __init__( assert isinstance(source0, Source) assert isinstance(source1, Source) - self.sacc_data_type: str = sacc_data_type + self.sacc_data_type: str + self.ccl_kind: str self.source0: Source = source0 - self.source1 = source1 - self.ell_for_xi_config: dict[str, int] = copy.deepcopy(ELL_FOR_XI_DEFAULTS) + self.source1: Source = source1 + + self.ell_for_xi_config: dict[str, int] + self.ell_or_theta_config: None | EllOrThetaConfig + self.ell_or_theta_min: None | float | int + self.ell_or_theta_max: None | float | int + self.window: None | Window + self.data_vector: None | DataVector + self.theory_vector: None | TheoryVector + self.sacc_tracers: TracerNames + self.ells: None | npt.NDArray[np.int64] + self.thetas: None | npt.NDArray[np.float64] + self.mean_ells: None | npt.NDArray[np.float64] + self.ells_for_xi: None | npt.NDArray[np.int64] + self.cells: dict[TracerNames, npt.NDArray[np.float64]] + + self._init_empty_default_attribs() if ell_for_xi is not None: self.ell_for_xi_config.update(ell_for_xi) - # What is the difference between the following 3 instance variables? 
- # ell_or_theta - # _ell_or_theta - # ell_or_theta_ self.ell_or_theta_config = ell_or_theta self.ell_or_theta_min = ell_or_theta_min self.ell_or_theta_max = ell_or_theta_max - self.window: None | Window = None - self.data_vector: None | DataVector = None - self.theory_vector: None | TheoryVector = None + self._set_ccl_kind(sacc_data_type) - self.sacc_tracers: TracerNames - self.ells: None | npt.NDArray[np.int64] = None - self.thetas: None | npt.NDArray[np.float64] = None - self.mean_ells: None | npt.NDArray[np.float64] = None - self.ells_for_xi: None | npt.NDArray[np.int64] = None + def _init_empty_default_attribs(self): + """Initialize the empty and default attributes.""" + self.ell_for_xi_config = copy.deepcopy(ELL_FOR_XI_DEFAULTS) + self.ell_or_theta_config = None + self.ell_or_theta_min = None + self.ell_or_theta_max = None + + self.window = None - self.cells: dict[TracerNames, npt.NDArray[np.float64]] = {} + self.data_vector = None + self.theory_vector = None + + self.ells = None + self.thetas = None + self.mean_ells = None + self.ells_for_xi = None + + self.cells = {} + + def _set_ccl_kind(self, sacc_data_type): + """Set the CCL kind for this statistic.""" + self.sacc_data_type = sacc_data_type if self.sacc_data_type in SACC_DATA_TYPE_TO_CCL_KIND: self.ccl_kind = SACC_DATA_TYPE_TO_CCL_KIND[self.sacc_data_type] else: - raise ValueError( - f"The SACC data type {sacc_data_type}'%s' is not " f"supported!" + raise ValueError(f"The SACC data type {sacc_data_type} is not supported!") + + @classmethod + def _from_metadata( + cls, + *, + sacc_data_type: str, + source0: Source, + source1: Source, + metadata: TwoPointCells | TwoPointCWindow | TwoPointXiTheta, + ) -> TwoPoint: + """Create a TwoPoint statistic from a TwoPointCells metadata object.""" + two_point = cls(sacc_data_type, source0, source1) + match metadata: + case TwoPointCells(): + two_point._init_from_cells(metadata) + case TwoPointCWindow(): + two_point._init_from_cwindow(metadata) + case TwoPointXiTheta(): + two_point._init_from_xi_theta(metadata) + case _: + raise ValueError(f"Metadata of type {type(metadata)} is not supported!") + return two_point + + def _init_from_cells(self, metadata: TwoPointCells): + """Initialize the TwoPoint statistic from a TwoPointCells metadata object.""" + self.sacc_tracers = metadata.XY.get_tracer_names() + self.ells = metadata.ells + self.window = None + if metadata.Cell is not None: + self.sacc_indices = metadata.Cell.indices + self.data_vector = DataVector.create(metadata.Cell.data) + self.ready = True + + def _init_from_cwindow(self, metadata: TwoPointCWindow): + """Initialize the TwoPoint statistic from a TwoPointCWindow metadata object.""" + self.sacc_tracers = metadata.XY.get_tracer_names() + self.window = metadata.window + if self.window.ells_for_interpolation is None: + self.window.ells_for_interpolation = calculate_ells_for_interpolation( + self.window ) + if metadata.Cell is not None: + self.sacc_indices = metadata.Cell.indices + self.data_vector = DataVector.create(metadata.Cell.data) + self.ready = True + + def _init_from_xi_theta(self, metadata: TwoPointXiTheta): + """Initialize the TwoPoint statistic from a TwoPointXiTheta metadata object.""" + self.sacc_tracers = metadata.XY.get_tracer_names() + self.thetas = metadata.thetas + self.window = None + self.ells_for_xi = _ell_for_xi(**self.ell_for_xi_config) + if metadata.xis is not None: + self.sacc_indices = metadata.xis.indices + self.data_vector = DataVector.create(metadata.xis.data) + self.ready = True @classmethod - def 
from_metadata_cells( + def _from_metadata_any( cls, - metadata: list[TwoPointCells], + metadata: Sequence[TwoPointCells | TwoPointCWindow | TwoPointXiTheta], wl_factory: WeakLensingFactory | None = None, nc_factory: NumberCountsFactory | None = None, ) -> UpdatableCollection[TwoPoint]: - """Create a TwoPoint statistic from a TwoPointCells metadata object.""" - sacc_data = create_sacc(metadata) + """Create an UpdatableCollection of TwoPoint statistics. + This constructor creates an UpdatableCollection of TwoPoint statistics from a + list of TwoPointCells, TwoPointCWindow or TwoPointXiTheta metadata objects. + The metadata objects are used to initialize the TwoPoint statistics. + """ two_point_list = [ - cls( - cell.get_sacc_name(), - use_source_factory( - cell.XY.x, wl_factory=wl_factory, nc_factory=nc_factory + cls._from_metadata( + sacc_data_type=cell.get_sacc_name(), + source0=use_source_factory( + cell.XY.x, + cell.XY.x_measurement, + wl_factory=wl_factory, + nc_factory=nc_factory, ), - use_source_factory( - cell.XY.y, wl_factory=wl_factory, nc_factory=nc_factory + source1=use_source_factory( + cell.XY.y, + cell.XY.y_measurement, + wl_factory=wl_factory, + nc_factory=nc_factory, ), + metadata=cell, ) for cell in metadata ] - for two_point in two_point_list: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - two_point.read(sacc_data) + return UpdatableCollection(two_point_list) + + @classmethod + def _from_metadata_only_any( + cls, + metadata: Sequence[TwoPointCellsIndex | TwoPointXiThetaIndex], + wl_factory: WeakLensingFactory | None = None, + nc_factory: NumberCountsFactory | None = None, + ) -> UpdatableCollection[TwoPoint]: + """Create an UpdatableCollection of TwoPoint statistics. + + This constructor creates an UpdatableCollection of TwoPoint statistics from a + list of TwoPointCells, TwoPointCWindow or TwoPointXiTheta metadata objects. + The metadata objects are used to initialize the TwoPoint statistics. + """ + two_point_list = [] + for cell_index in metadata: + n1, a, n2, b = measurements_from_index(cell_index) + two_point = cls( + sacc_data_type=cell_index["data_type"], + source0=use_source_factory_metadata_only( + n1, a, wl_factory=wl_factory, nc_factory=nc_factory + ), + source1=use_source_factory_metadata_only( + n2, b, wl_factory=wl_factory, nc_factory=nc_factory + ), + ) + two_point_list.append(two_point) return UpdatableCollection(two_point_list) + @classmethod + def from_metadata_harmonic( + cls, + metadata: Sequence[TwoPointCells | TwoPointCWindow], + wl_factory: WeakLensingFactory | None = None, + nc_factory: NumberCountsFactory | None = None, + check_consistence: bool = False, + ) -> UpdatableCollection[TwoPoint]: + """Create an UpdatableCollection of harmonic space TwoPoint statistics. + + This constructor creates an UpdatableCollection of TwoPoint statistics from a + list of TwoPointCells or TwoPointCWindow metadata objects. The metadata objects + are used to initialize the TwoPoint statistics. + + :param metadata: The metadata objects to initialize the TwoPoint statistics. + :param wl_factory: The weak lensing factory to use. + :param nc_factory: The number counts factory to use. + :param check_consistence: Whether to check the consistence of the + metadata and data. 
+ """ + if check_consistence: + check_two_point_consistence_harmonic(metadata) + return cls._from_metadata_any(metadata, wl_factory, nc_factory) + + @classmethod + def from_metadata_real( + cls, + metadata: Sequence[TwoPointXiTheta], + wl_factory: WeakLensingFactory | None = None, + nc_factory: NumberCountsFactory | None = None, + check_consistence: bool = False, + ) -> UpdatableCollection[TwoPoint]: + """Create an UpdatableCollection of real space TwoPoint statistics. + + This constructor creates an UpdatableCollection of TwoPoint statistics from a + list of TwoPointXiTheta metadata objects. The metadata objects are used to + initialize the TwoPoint statistics. + + :param metadata: The metadata objects to initialize the TwoPoint statistics. + :param wl_factory: The weak lensing factory to use. + :param nc_factory: The number counts factory to use. + :param check_consistence: Whether to check the consistence of the + metadata and data. + """ + if check_consistence: + check_two_point_consistence_real(metadata) + return cls._from_metadata_any(metadata, wl_factory, nc_factory) + + @classmethod + def from_metadata_only_harmonic( + cls, + metadata: list[TwoPointCellsIndex], + wl_factory: WeakLensingFactory | None = None, + nc_factory: NumberCountsFactory | None = None, + ) -> UpdatableCollection[TwoPoint]: + """Create a TwoPoint from metadata only.""" + return cls._from_metadata_only_any(metadata, wl_factory, nc_factory) + + @classmethod + def from_metadata_only_real( + cls, + metadata: list[TwoPointXiThetaIndex], + wl_factory: WeakLensingFactory | None = None, + nc_factory: NumberCountsFactory | None = None, + ) -> UpdatableCollection[TwoPoint]: + """Create a TwoPoint from metadata only.""" + return cls._from_metadata_only_any(metadata, wl_factory, nc_factory) + def read_ell_cells( self, sacc_data_type: str, sacc_data: sacc.Sacc, tracers: TracerNames ) -> ( @@ -626,7 +823,6 @@ def compute_theory_vector_real_space(self, tools: ModelingTools) -> TheoryVector theta=self.thetas / 60, type=self.ccl_kind, ) - assert self.data_vector is not None return TheoryVector.create(theory_vector) def compute_theory_vector_harmonic_space( @@ -644,7 +840,7 @@ def compute_theory_vector_harmonic_space( scale1 = self.source1.get_scale() assert self.ccl_kind == "cl" - assert self.ells is not None + assert (self.ells is not None) or (self.window is not None) if self.window is not None: # If a window function is provided, we need to compute the Cl's @@ -697,7 +893,6 @@ def compute_theory_vector_harmonic_space( tracers1, ) - assert self.data_vector is not None return TheoryVector.create(theory_vector) def _compute_theory_vector(self, tools: ModelingTools) -> TheoryVector: @@ -737,7 +932,6 @@ def compute_cells( * scale0 * scale1 ) - # Add up all the contributions to the cells self.cells[TRACER_NAMES_TOTAL] = np.array(sum(self.cells.values())) theory_vector = self.cells[TRACER_NAMES_TOTAL] return theory_vector diff --git a/firecrown/likelihood/weak_lensing.py b/firecrown/likelihood/weak_lensing.py index 49165950..b1102df8 100644 --- a/firecrown/likelihood/weak_lensing.py +++ b/firecrown/likelihood/weak_lensing.py @@ -358,14 +358,14 @@ class MultiplicativeShearBiasFactory(BaseModel): Field(description="The type of the systematic."), ] = "MultiplicativeShearBiasFactory" - def create(self, inferred_zdist: InferredGalaxyZDist) -> MultiplicativeShearBias: + def create(self, bin_name: str) -> MultiplicativeShearBias: """Create a MultiplicativeShearBias object. 
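The weak-lensing factory hunks below mirror the number-counts changes above: per-bin systematic factories now receive only the bin name instead of a full InferredGalaxyZDist. A one-line sketch of the new call shape, with an illustrative bin name:

    from firecrown.likelihood.weak_lensing import PhotoZShiftFactory

    shift = PhotoZShiftFactory().create("src0")  # only the tracer/bin name is needed now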
-        :param inferred_zdist: The inferred galaxy redshift distribution for the
-            created MultiplicativeShearBias object.
+        :param bin_name: The name of the tracer bin for the created
+            MultiplicativeShearBias object.
         :return: The created MultiplicativeShearBias object.
         """
-        return MultiplicativeShearBias(inferred_zdist.bin_name)
+        return MultiplicativeShearBias(bin_name)
 
     def create_global(self) -> MultiplicativeShearBias:
         """Create a MultiplicativeShearBias object.
@@ -387,14 +387,14 @@ class LinearAlignmentSystematicFactory(BaseModel):
 
     alphag: None | float = 1.0
 
-    def create(self, inferred_zdist: InferredGalaxyZDist) -> LinearAlignmentSystematic:
+    def create(self, bin_name: str) -> LinearAlignmentSystematic:
         """Create a LinearAlignmentSystematic object.
 
-        :param inferred_zdist: The inferred galaxy redshift distribution for the
-            created LinearAlignmentSystematic object.
+        :param bin_name: The name of the tracer bin for the created
+            LinearAlignmentSystematic object.
         :return: The created LinearAlignmentSystematic object.
         """
-        return LinearAlignmentSystematic(inferred_zdist.bin_name)
+        return LinearAlignmentSystematic(bin_name)
 
     def create_global(self) -> LinearAlignmentSystematic:
         """Create a LinearAlignmentSystematic object.
@@ -414,14 +414,14 @@ class TattAlignmentSystematicFactory(BaseModel):
         Field(description="The type of the systematic."),
     ] = "TattAlignmentSystematicFactory"
 
-    def create(self, inferred_zdist: InferredGalaxyZDist) -> TattAlignmentSystematic:
+    def create(self, bin_name: str) -> TattAlignmentSystematic:
         """Create a TattAlignmentSystematic object.
 
-        :param inferred_zdist: The inferred galaxy redshift distribution for the
-            created TattAlignmentSystematic object.
+        :param bin_name: The name of the tracer bin for the created
+            TattAlignmentSystematic object.
         :return: The created TattAlignmentSystematic object.
         """
-        return TattAlignmentSystematic(inferred_zdist.bin_name)
+        return TattAlignmentSystematic(bin_name)
 
     def create_global(self) -> TattAlignmentSystematic:
         """Create a TattAlignmentSystematic object.
@@ -440,14 +440,14 @@ class PhotoZShiftFactory(BaseModel):
         Literal["PhotoZShiftFactory"], Field(description="The type of the systematic.")
     ] = "PhotoZShiftFactory"
 
-    def create(self, inferred_zdist: InferredGalaxyZDist) -> PhotoZShift:
+    def create(self, bin_name: str) -> PhotoZShift:
         """Create a PhotoZShift object.
 
-        :param inferred_zdist: The inferred galaxy redshift distribution for the
-            created PhotoZShift object.
+        :param bin_name: The name of the tracer bin for the created
+            PhotoZShift object.
         :return: The created PhotoZShift object.
         """
-        return PhotoZShift(inferred_zdist.bin_name)
+        return PhotoZShift(bin_name)
 
     def create_global(self) -> PhotoZShift:
         """Create a PhotoZShift object.
@@ -494,7 +494,7 @@ def create(self, inferred_zdist: InferredGalaxyZDist) -> WeakLensing: return self._cache[inferred_zdist_id] systematics: list[SourceGalaxySystematic[WeakLensingArgs]] = [ - systematic_factory.create(inferred_zdist) + systematic_factory.create(inferred_zdist.bin_name) for systematic_factory in self.per_bin_systematics ] systematics.extend(self._global_systematics_instances) @@ -503,3 +503,22 @@ def create(self, inferred_zdist: InferredGalaxyZDist) -> WeakLensing: self._cache[inferred_zdist_id] = wl return wl + + def create_from_metadata_only( + self, + sacc_tracer: str, + ) -> WeakLensing: + """Create an WeakLensing object with the given tracer name and scale.""" + sacc_tracer_id = hash(sacc_tracer) # Improve this + if sacc_tracer_id in self._cache: + return self._cache[sacc_tracer_id] + systematics: list[SourceGalaxySystematic[WeakLensingArgs]] = [ + systematic_factory.create(sacc_tracer) + for systematic_factory in self.per_bin_systematics + ] + systematics.extend(self._global_systematics_instances) + + wl = WeakLensing(sacc_tracer=sacc_tracer, systematics=systematics) + self._cache[sacc_tracer_id] = wl + + return wl diff --git a/firecrown/metadata/two_point.py b/firecrown/metadata/two_point.py index 29b2bd24..1930bd5a 100644 --- a/firecrown/metadata/two_point.py +++ b/firecrown/metadata/two_point.py @@ -4,531 +4,32 @@ metadata from a sacc file. """ -from itertools import combinations_with_replacement, chain -from typing import TypedDict, TypeVar, Type -from dataclasses import dataclass -from enum import Enum, auto -import re +from itertools import combinations_with_replacement, product +import hashlib +from typing import TypedDict, Sequence import numpy as np import numpy.typing as npt - -import yaml -from yaml import CLoader as Loader -from yaml import CDumper as Dumper - import sacc from sacc.data_types import required_tags -from firecrown.utils import compare_optional_arrays - -ST = TypeVar("ST") # This will be used in YAMLSerializable - - -class YAMLSerializable: - """Protocol for classes that can be serialized to and from YAML.""" - - def to_yaml(self: ST) -> str: - """Return the YAML representation of the object.""" - return yaml.dump(self, Dumper=Dumper, sort_keys=False) - - @classmethod - def from_yaml(cls: Type[ST], yaml_str: str) -> ST: - """Load the object from YAML.""" - return yaml.load(yaml_str, Loader=Loader) - - -@dataclass(frozen=True) -class TracerNames(YAMLSerializable): - """The names of the two tracers in the sacc file.""" - - name1: str - name2: str - - def __getitem__(self, item): - """Get the name of the tracer at the given index.""" - if item == 0: - return self.name1 - if item == 1: - return self.name2 - raise IndexError - - def __iter__(self): - """Iterate through the data members. - - This is to allow automatic unpacking. - """ - yield self.name1 - yield self.name2 - - -TRACER_NAMES_TOTAL = TracerNames("", "") # special name to represent total - - -class Galaxies(YAMLSerializable, str, Enum): - """This enumeration type for galaxy measurements. - - It provides identifiers for the different types of galaxy-related types of - measurement. - - SACC has some notion of supporting other types, but incomplete implementation. When - support for more types is added to SACC this enumeration needs to be updated. - """ - - SHEAR_E = auto() - SHEAR_T = auto() - COUNTS = auto() - - def sacc_type_name(self) -> str: - """Return the lower-case form of the main measurement type. 
- - This is the first part of the SACC string used to denote a correlation between - measurements of this type. - """ - return "galaxy" - - def sacc_measurement_name(self) -> str: - """Return the lower-case form of the specific measurement type. - - This is the second part of the SACC string used to denote the specific - measurement type. - """ - if self == Galaxies.SHEAR_E: - return "shear" - if self == Galaxies.SHEAR_T: - return "shear" - if self == Galaxies.COUNTS: - return "density" - raise ValueError("Untranslated Galaxy Measurement encountered") - - def polarization(self) -> str: - """Return the SACC polarization code. - - This is the third part of the SACC string used to denote the specific - measurement type. - """ - if self == Galaxies.SHEAR_E: - return "e" - if self == Galaxies.SHEAR_T: - return "t" - if self == Galaxies.COUNTS: - return "" - raise ValueError("Untranslated Galaxy Measurement encountered") - - def __lt__(self, other): - """Define a comparison function for the Galaxy Measurement enumeration.""" - return compare_enums(self, other) < 0 - - def __eq__(self, other): - """Define an equality test for Galaxy Measurement enumeration.""" - return compare_enums(self, other) == 0 - - def __ne__(self, other): - """Negation of __eq__.""" - return not self.__eq__(other) - - def __hash__(self) -> int: - """Define a hash function that uses both type and value information.""" - return hash((Galaxies, self.value)) - - -class CMB(YAMLSerializable, str, Enum): - """This enumeration type for CMB measurements. - - It provides identifiers for the different types of CMB-related types of - measurement. - - SACC has some notion of supporting other types, but incomplete implementation. When - support for more types is added to SACC this enumeration needs to be updated. - """ - - CONVERGENCE = auto() - - def sacc_type_name(self) -> str: - """Return the lower-case form of the main measurement type. - - This is the first part of the SACC string used to denote a correlation between - measurements of this type. - """ - return "cmb" - - def sacc_measurement_name(self) -> str: - """Return the lower-case form of the specific measurement type. - - This is the second part of the SACC string used to denote the specific - measurement type. - """ - if self == CMB.CONVERGENCE: - return "convergence" - raise ValueError("Untranslated CMBMeasurement encountered") - - def polarization(self) -> str: - """Return the SACC polarization code. - - This is the third part of the SACC string used to denote the specific - measurement type. - """ - if self == CMB.CONVERGENCE: - return "" - raise ValueError("Untranslated CMBMeasurement encountered") - - def __lt__(self, other): - """Define a comparison function for the CMBMeasurement enumeration.""" - return compare_enums(self, other) < 0 - - def __eq__(self, other): - """Define an equality test for CMBMeasurement enumeration.""" - return compare_enums(self, other) == 0 - - def __ne__(self, other): - """Negation of __eq__.""" - return not self.__eq__(other) - - def __hash__(self) -> int: - """Define a hash function that uses both type and value information.""" - return hash((CMB, self.value)) - - -class Clusters(YAMLSerializable, str, Enum): - """This enumeration type for cluster measurements. - - It provides identifiers for the different types of cluster-related types of - measurement. - - SACC has some notion of supporting other types, but incomplete implementation. When - support for more types is added to SACC this enumeration needs to be updated. 
- """ - - COUNTS = auto() - - def sacc_type_name(self) -> str: - """Return the lower-case form of the main measurement type. - - This is the first part of the SACC string used to denote a correlation between - measurements of this type. - """ - return "cluster" - - def sacc_measurement_name(self) -> str: - """Return the lower-case form of the specific measurement type. - - This is the second part of the SACC string used to denote the specific - measurement type. - """ - if self == Clusters.COUNTS: - return "density" - raise ValueError("Untranslated ClusterMeasurement encountered") - - def polarization(self) -> str: - """Return the SACC polarization code. - - This is the third part of the SACC string used to denote the specific - measurement type. - """ - if self == Clusters.COUNTS: - return "" - raise ValueError("Untranslated ClusterMeasurement encountered") - - def __lt__(self, other): - """Define a comparison function for the ClusterMeasurement enumeration.""" - return compare_enums(self, other) < 0 - - def __eq__(self, other): - """Define an equality test for ClusterMeasurement enumeration.""" - return compare_enums(self, other) == 0 - - def __ne__(self, other): - """Negation of __eq__.""" - return not self.__eq__(other) - - def __hash__(self) -> int: - """Define a hash function that uses both type and value information.""" - return hash((Clusters, self.value)) - - -Measurement = Galaxies | CMB | Clusters -ALL_MEASUREMENTS: list[Measurement] = list(chain(Galaxies, CMB, Clusters)) -ALL_MEASUREMENT_TYPES = (Galaxies, CMB, Clusters) -HARMONIC_ONLY_MEASUREMENTS = (Galaxies.SHEAR_E,) -REAL_ONLY_MEASUREMENTS = (Galaxies.SHEAR_T,) - - -def make_measurement_dict(value: Measurement) -> dict[str, str]: - """Create a dictionary from a Measurement object. - - :param value: the measurement to turn into a dictionary - """ - return {"subject": type(value).__name__, "property": value.name} - - -def measurement_is_compatible(a: Measurement, b: Measurement) -> bool: - """Check if two Measurement are compatible. - - Two Measurement are compatible if they can be correlated in a two-point function. - """ - if a in HARMONIC_ONLY_MEASUREMENTS and b in REAL_ONLY_MEASUREMENTS: - return False - if a in REAL_ONLY_MEASUREMENTS and b in HARMONIC_ONLY_MEASUREMENTS: - return False - return True - - -def measurement_supports_real(x: Measurement) -> bool: - """Return True if x supports real-space calculations.""" - return x not in HARMONIC_ONLY_MEASUREMENTS - - -def measurement_supports_harmonic(x: Measurement) -> bool: - """Return True if x supports harmonic-space calculations.""" - return x not in REAL_ONLY_MEASUREMENTS - - -def compare_enums(a: Measurement, b: Measurement) -> int: - """Define a comparison function for the Measurement enumeration. - - Return -1 if a comes before b, 0 if they are the same, and +1 if b comes before a. - """ - order = (CMB, Clusters, Galaxies) - main_type_index_a = order.index(type(a)) - main_type_index_b = order.index(type(b)) - if main_type_index_a == main_type_index_b: - return int(a) - int(b) - return main_type_index_a - main_type_index_b - - -@dataclass(frozen=True, kw_only=True) -class InferredGalaxyZDist(YAMLSerializable): - """The class used to store the redshift resolution data for a sacc file. - - The sacc file is a complicated set of tracers (bins) and surveys. This class is - used to store the redshift resolution data for a single photometric bin. 
- """ - - bin_name: str - z: np.ndarray - dndz: np.ndarray - measurement: Measurement - - def __post_init__(self) -> None: - """Validate the redshift resolution data. - - - Make sure the z and dndz arrays have the same shape; - - The measurement must be of type Measurement. - - The bin_name should not be empty. - """ - if self.z.shape != self.dndz.shape: - raise ValueError("The z and dndz arrays should have the same shape.") - - if not isinstance(self.measurement, ALL_MEASUREMENT_TYPES): - raise ValueError("The measurement should be a Measurement.") - - if self.bin_name == "": - raise ValueError("The bin_name should not be empty.") - - def __eq__(self, other): - """Equality test for InferredGalaxyZDist. - - Two InferredGalaxyZDist are equal if they have equal bin_name, z, dndz, and - measurement. - """ - return ( - self.bin_name == other.bin_name - and np.array_equal(self.z, other.z) - and np.array_equal(self.dndz, other.dndz) - and self.measurement == other.measurement - ) - - -@dataclass(frozen=True, kw_only=True) -class TwoPointXY(YAMLSerializable): - """Class defining a two-point correlation pair of redshift resolutions. - - It is used to store the two redshift resolutions for the two bins being - correlated. - """ - - x: InferredGalaxyZDist - y: InferredGalaxyZDist - - def __post_init__(self) -> None: - """Make sure the two redshift resolutions are compatible.""" - if not measurement_is_compatible(self.x.measurement, self.y.measurement): - raise ValueError( - f"Measurements {self.x.measurement} and {self.y.measurement} " - f"are not compatible." - ) - - def __eq__(self, other) -> bool: - """Equality test for TwoPointXY objects.""" - return self.x == other.x and self.y == other.y - - -@dataclass(frozen=True, kw_only=True) -class TwoPointCells(YAMLSerializable): - """Class defining the metadata for an harmonic-space two-point measurement. - - The class used to store the metadata for a (spherical) harmonic-space two-point - function measured on a sphere. - - This includes the two redshift resolutions (one for each binned quantity) and the - array of (integer) l's at which the two-point function which has this metadata were - calculated. - """ - - XY: TwoPointXY - ells: npt.NDArray[np.int64] - - def __post_init__(self) -> None: - """Validate the TwoPointCells data. - - Make sure the ells are a 1D array and X and Y are compatible - with harmonic-space calculations. - """ - if len(self.ells.shape) != 1: - raise ValueError("Ells should be a 1D array.") - - if not measurement_supports_harmonic( - self.XY.x.measurement - ) or not measurement_supports_harmonic(self.XY.y.measurement): - raise ValueError( - f"Measurements {self.XY.x.measurement} and " - f"{self.XY.y.measurement} must support harmonic-space calculations." - ) - - def __eq__(self, other) -> bool: - """Equality test for TwoPointCells objects.""" - return self.XY == other.XY and np.array_equal(self.ells, other.ells) - - def get_sacc_name(self) -> str: - """Return the SACC name for the two-point function.""" - return type_to_sacc_string_harmonic( - self.XY.x.measurement, self.XY.y.measurement - ) - - -@dataclass(kw_only=True) -class Window(YAMLSerializable): - """The class used to represent a window function. - - It contains the ells at which the window function is defined, the weights - of the window function, and the ells at which the window function is - interpolated. - - It may contain the ells for interpolation if the theory prediction is - calculated at a different set of ells than the window function. 
- """ - - ells: npt.NDArray[np.int64] - weights: npt.NDArray[np.float64] - ells_for_interpolation: None | npt.NDArray[np.int64] = None - - def __post_init__(self) -> None: - """Make sure the weights have the right shape.""" - if len(self.ells.shape) != 1: - raise ValueError("Ells should be a 1D array.") - if len(self.weights.shape) != 2: - raise ValueError("Weights should be a 2D array.") - if self.weights.shape[0] != len(self.ells): - raise ValueError("Weights should have the same number of rows as ells.") - if ( - self.ells_for_interpolation is not None - and len(self.ells_for_interpolation.shape) != 1 - ): - raise ValueError("Ells for interpolation should be a 1D array.") - - def n_observations(self) -> int: - """Return the number of observations supported by the window function.""" - return self.weights.shape[1] - - def __eq__(self, other) -> bool: - """Equality test for Window objects.""" - assert isinstance(other, Window) - # We will need special handling for the optional ells_for_interpolation. - # First handle the non-optinal parts. - partial_result = np.array_equal(self.ells, other.ells) and np.array_equal( - self.weights, other.weights - ) - if not partial_result: - return False - return compare_optional_arrays( - self.ells_for_interpolation, other.ells_for_interpolation - ) - - -@dataclass(frozen=True, kw_only=True) -class TwoPointCWindow(YAMLSerializable): - """Two-point function with a window function. - - The class used to store the metadata for a (spherical) harmonic-space two-point - function measured on a sphere, with an associated window function. - - This includes the two redshift resolutions (one for each binned quantity) and the - matrix (window function) that relates the measured Cl's with the predicted Cl's. - - Note that the matrix `window` always has l=0 and l=1 suppressed. - """ - - XY: TwoPointXY - window: Window - - def __post_init__(self): - """Validate the TwoPointCWindow data. - - Make sure the window is - """ - if not isinstance(self.window, Window): - raise ValueError("Window should be a Window object.") - - if not measurement_supports_harmonic( - self.XY.x.measurement - ) or not measurement_supports_harmonic(self.XY.y.measurement): - raise ValueError( - f"Measurements {self.XY.x.measurement} and " - f"{self.XY.y.measurement} must support harmonic-space calculations." - ) - - def get_sacc_name(self) -> str: - """Return the SACC name for the two-point function.""" - return type_to_sacc_string_harmonic( - self.XY.x.measurement, self.XY.y.measurement - ) - - -@dataclass(frozen=True, kw_only=True) -class TwoPointXiTheta(YAMLSerializable): - """Class defining the metadata for a real-space two-point measurement. - - The class used to store the metadata for a real-space two-point function measured - on a sphere. - - This includes the two redshift resolutions (one for each binned quantity) and the a - array of (floating point) theta (angle) values at which the two-point function - which has this metadata were calculated. - """ - - XY: TwoPointXY - thetas: npt.NDArray[np.float64] - - def __post_init__(self): - """Validate the TwoPointCWindow data. - - Make sure the window is - """ - if not measurement_supports_real( - self.XY.x.measurement - ) or not measurement_supports_real(self.XY.y.measurement): - raise ValueError( - f"Measurements {self.XY.x.measurement} and " - f"{self.XY.y.measurement} must support real-space calculations." 
- ) - - def get_sacc_name(self) -> str: - """Return the SACC name for the two-point function.""" - return type_to_sacc_string_real(self.XY.x.measurement, self.XY.y.measurement) - - def __eq__(self, other) -> bool: - """Equality test for TwoPointXiTheta objects.""" - return self.XY == other.XY and np.array_equal(self.thetas, other.thetas) +from firecrown.metadata.two_point_types import ( + TracerNames, + Measurement, + InferredGalaxyZDist, + TwoPointXY, + Window, + TwoPointCells, + TwoPointCWindow, + TwoPointXiTheta, + TwoPointMeasurement, + LENS_REGEX, + SOURCE_REGEX, + MEASURED_TYPE_STRING_MAP, + measurement_is_compatible, + GALAXY_LENS_TYPES, + GALAXY_SOURCE_TYPES, +) # TwoPointXiThetaIndex is a type used to create intermediate objects when @@ -555,46 +56,112 @@ def __eq__(self, other) -> bool: ) -def _extract_candidate_data_types( - tracer_name: str, data_points: list[sacc.DataPoint] -) -> list[Measurement]: - """Extract the candidate Measurement for a tracer. +def _extract_all_candidate_data_types( + data_points: list[sacc.DataPoint], + include_maybe_types: bool = False, +) -> dict[str, set[Measurement]]: + """Extract all candidate Measurement from the data points. - An exception is raise if the tracer does not have any associated data points. + The candidate Measurement are the ones that appear in the data points. """ - tracer_associated_types = { - d.data_type for d in data_points if tracer_name in d.tracers + all_data_types: set[tuple[str, str, str]] = { + (d.data_type, d.tracers[0], d.tracers[1]) for d in data_points } - tracer_associated_types_len = len(tracer_associated_types) + sure_types, maybe_types = _extract_sure_and_maybe_types(all_data_types) - type_count: dict[Measurement, int] = { - measurement: 0 for measurement in ALL_MEASUREMENTS - } - for data_type in tracer_associated_types: + # Remove the sure types from the maybe types. + for tracer0, sure_types0 in sure_types.items(): + maybe_types[tracer0] -= sure_types0 + + # Filter maybe types. + for data_type, tracer1, tracer2 in all_data_types: + if data_type not in MEASURED_TYPE_STRING_MAP: + continue + a, b = MEASURED_TYPE_STRING_MAP[data_type] + + if a == b: + continue + + if a in sure_types[tracer1] and b in sure_types[tracer2]: + maybe_types[tracer1].discard(b) + maybe_types[tracer2].discard(a) + elif a in sure_types[tracer2] and b in sure_types[tracer1]: + maybe_types[tracer1].discard(a) + maybe_types[tracer2].discard(b) + + if include_maybe_types: + return { + tracer0: sure_types0 | maybe_types[tracer0] + for tracer0, sure_types0 in sure_types.items() + } + return sure_types + + +def _extract_sure_and_maybe_types(all_data_types): + sure_types: dict[str, set[Measurement]] = {} + maybe_types: dict[str, set[Measurement]] = {} + + for data_type, tracer1, tracer2 in all_data_types: + sure_types[tracer1] = set() + sure_types[tracer2] = set() + maybe_types[tracer1] = set() + maybe_types[tracer2] = set() + + # Getting the sure and maybe types for each tracer. 
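+    # A measurement type is "sure" when it is pinned down unambiguously: either
+    # the data type is an auto-correlation of a single measurement, or the
+    # tracer names follow the lens/source naming convention. Otherwise both
+    # candidate measurements are recorded only as "maybe" types for each tracer.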
+ for data_type, tracer1, tracer2 in all_data_types: if data_type not in MEASURED_TYPE_STRING_MAP: continue a, b = MEASURED_TYPE_STRING_MAP[data_type] - if a != b: - type_count[a] += 1 - type_count[b] += 1 + if a == b: + sure_types[tracer1].update({a}) + sure_types[tracer2].update({a}) else: - type_count[a] += 1 + name_match, n1, a, n2, b = match_name_type(tracer1, tracer2, a, b) + if name_match: + sure_types[n1].update({a}) + sure_types[n2].update({b}) + if not name_match: + maybe_types[tracer1].update({a, b}) + maybe_types[tracer2].update({a, b}) + return sure_types, maybe_types + + +def match_name_type( + tracer1: str, + tracer2: str, + a: Measurement, + b: Measurement, + require_convetion: bool = False, +) -> tuple[bool, str, Measurement, str, Measurement]: + """Use the naming convention to assign the right measurement to each tracer.""" + for n1, n2 in ((tracer1, tracer2), (tracer2, tracer1)): + if LENS_REGEX.match(n1) and SOURCE_REGEX.match(n2): + if a in GALAXY_SOURCE_TYPES and b in GALAXY_LENS_TYPES: + return True, n1, b, n2, a + if b in GALAXY_SOURCE_TYPES and a in GALAXY_LENS_TYPES: + return True, n1, a, n2, b + raise ValueError( + "Invalid SACC file, tracer names do not respect " + "the naming convetion." + ) + if require_convetion: + if LENS_REGEX.match(tracer1) and LENS_REGEX.match(tracer2): + return False, tracer1, a, tracer2, b + if SOURCE_REGEX.match(tracer1) and SOURCE_REGEX.match(tracer2): + return False, tracer1, a, tracer2, b - result = [ - measurement - for measurement, count in type_count.items() - if count == tracer_associated_types_len - ] - if len(result) == 0: raise ValueError( - f"Tracer {tracer_name} does not have data points associated with it. " - f"Inconsistent SACC object." + f"Invalid tracer names ({tracer1}, {tracer2}) " + f"do not respect the naming convetion." ) - return result + + return False, tracer1, a, tracer2, b -def extract_all_tracers(sacc_data: sacc.Sacc) -> list[InferredGalaxyZDist]: +def extract_all_tracers( + sacc_data: sacc.Sacc, include_maybe_types=False +) -> list[InferredGalaxyZDist]: """Extracts the two-point function metadata from a Sacc object. The Sacc object contains a set of tracers (one-dimensional bins) and data @@ -604,69 +171,42 @@ def extract_all_tracers(sacc_data: sacc.Sacc) -> list[InferredGalaxyZDist]: and returns it in a list. """ tracers: list[sacc.tracers.BaseTracer] = sacc_data.tracers.values() - - data_points = sacc_data.get_data_points() - - inferred_galaxy_zdists = [] - - for tracer in tracers: - candidate_measurements = _extract_candidate_data_types(tracer.name, data_points) - - measurement = extract_measurement(candidate_measurements, tracer) - - inferred_galaxy_zdists.append( - InferredGalaxyZDist( - bin_name=tracer.name, - z=tracer.z, - dndz=tracer.nz, - measurement=measurement, + tracer_types = extract_all_tracers_types( + sacc_data, include_maybe_types=include_maybe_types + ) + for tracer0, tracer_types0 in tracer_types.items(): + if len(tracer_types0) == 0: + raise ValueError( + f"Tracer {tracer0} does not have data points associated with it. " + f"Inconsistent SACC object." ) + + return [ + InferredGalaxyZDist( + bin_name=tracer.name, + z=tracer.z, + dndz=tracer.nz, + measurements=tracer_types[tracer.name], ) + for tracer in tracers + ] - return inferred_galaxy_zdists +def extract_all_tracers_types( + sacc_data: sacc.Sacc, + include_maybe_types: bool = False, +) -> dict[str, set[Measurement]]: + """Extracts the two-point function metadata from a Sacc object. 
-def extract_measurement( - candidate_measurements: list[Galaxies | CMB | Clusters], - tracer: sacc.tracers.BaseTracer, -) -> Galaxies | CMB | Clusters: - """Extract from tracer a single type of measurement. + The Sacc object contains a set of tracers (one-dimensional bins) and data + points (measurements of the correlation between two tracers). - Only types in candidate_measurements will be considered. + This function extracts the two-point function metadata from the Sacc object + and returns it in a list. """ - if len(candidate_measurements) == 1: - # Only one Measurement appears in all associated data points. - # We can infer the Measurement from the data points. - measurement = candidate_measurements[0] - else: - # We cannot infer the Measurement from the associated data points. - # We need to check the tracer name. - if LENS_REGEX.match(tracer.name): - if Galaxies.COUNTS not in candidate_measurements: - raise ValueError( - f"Tracer {tracer.name} matches the lens regex but does " - f"not have a compatible Measurement. Inconsistent SACC " - f"object." - ) - measurement = Galaxies.COUNTS - elif SOURCE_REGEX.match(tracer.name): - # The source tracers can be either shear E or shear T. - if Galaxies.SHEAR_E in candidate_measurements: - measurement = Galaxies.SHEAR_E - elif Galaxies.SHEAR_T in candidate_measurements: - measurement = Galaxies.SHEAR_T - else: - raise ValueError( - f"Tracer {tracer.name} matches the source regex but does " - f"not have a compatible Measurement. Inconsistent SACC " - f"object." - ) - else: - raise ValueError( - f"Tracer {tracer.name} does not have a compatible Measurement. " - f"Inconsistent SACC object." - ) - return measurement + data_points = sacc_data.get_data_points() + + return _extract_all_candidate_data_types(data_points, include_maybe_types) def extract_all_data_types_xi_thetas( @@ -755,9 +295,12 @@ def extract_all_data_types_cells( def extract_all_photoz_bin_combinations( sacc_data: sacc.Sacc, + include_maybe_types: bool = False, ) -> list[TwoPointXY]: """Extracts the two-point function metadata from a sacc file.""" - inferred_galaxy_zdists = extract_all_tracers(sacc_data) + inferred_galaxy_zdists = extract_all_tracers( + sacc_data, include_maybe_types=include_maybe_types + ) bin_combinations = make_all_photoz_bin_combinations(inferred_galaxy_zdists) return bin_combinations @@ -779,86 +322,251 @@ def extract_window_function( ) -LENS_REGEX = re.compile(r"^lens\d+$") -SOURCE_REGEX = re.compile(r"^(src\d+|source\d+)$") +def _build_two_point_xy( + inferred_galaxy_zdists_dict, tracer_names, data_type +) -> TwoPointXY: + """Build a TwoPointXY object from the inferred galaxy z distributions. + The TwoPointXY object is built from the inferred galaxy z distributions, the data + type, and the tracer names. 
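+
+    As an illustration, for data_type "galaxy_shearDensity_cl_e" and tracers
+    ("lens0", "src0"), Galaxies.COUNTS is assigned to the bin that declares it
+    and Galaxies.SHEAR_E to the other; a ValueError is raised only when both
+    assignments are possible for two distinct measurements.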
+ """ + a, b = MEASURED_TYPE_STRING_MAP[data_type] -def make_all_photoz_bin_combinations( - inferred_galaxy_zdists: list[InferredGalaxyZDist], -) -> list[TwoPointXY]: - """Extract the two-point function metadata from a sacc file.""" - bin_combinations = [ - TwoPointXY(x=igz1, y=igz2) - for igz1, igz2 in combinations_with_replacement(inferred_galaxy_zdists, 2) - if measurement_is_compatible(igz1.measurement, igz2.measurement) - ] + igz1 = inferred_galaxy_zdists_dict[tracer_names[0]] + igz2 = inferred_galaxy_zdists_dict[tracer_names[1]] - return bin_combinations + ab = a in igz1.measurements and b in igz2.measurements + ba = b in igz1.measurements and a in igz2.measurements + if a != b and ab and ba: + raise ValueError( + f"Ambiguous measurements for tracers {tracer_names}. " + f"Impossible to determine which measurement is from which tracer." + ) + XY = TwoPointXY( + x=igz1, y=igz2, x_measurement=a if ab else b, y_measurement=b if ab else a + ) + return XY -def _type_to_sacc_string_common(x: Measurement, y: Measurement) -> str: - """Return the first two parts of the SACC string. - The first two parts of the SACC string is used to denote a correlation between - measurements of x and y. - """ - a, b = sorted([x, y]) - if isinstance(a, type(b)): - part_1 = f"{a.sacc_type_name()}_" - if a == b: - part_2 = f"{a.sacc_measurement_name()}_" - else: - part_2 = ( - f"{a.sacc_measurement_name()}{b.sacc_measurement_name().capitalize()}_" +def extract_all_data_cells( + sacc_data: sacc.Sacc, + allowed_data_type: None | list[str] = None, + include_maybe_types=False, +) -> tuple[list[TwoPointCells], list[TwoPointCWindow]]: + """Extract the two-point function metadata and data from a sacc file.""" + inferred_galaxy_zdists_dict = { + igz.bin_name: igz + for igz in extract_all_tracers( + sacc_data, include_maybe_types=include_maybe_types + ) + } + + if sacc_data.covariance is None or sacc_data.covariance.dense is None: + raise ValueError("The SACC object does not have a covariance matrix.") + cov_hash = hashlib.sha256(sacc_data.covariance.dense).hexdigest() + + two_point_cells = [] + two_point_cwindows = [] + for cell_index in extract_all_data_types_cells(sacc_data, allowed_data_type): + tracer_names = cell_index["tracer_names"] + ells = cell_index["ells"] + data_type = cell_index["data_type"] + + XY = _build_two_point_xy(inferred_galaxy_zdists_dict, tracer_names, data_type) + + ells, Cells, indices = sacc_data.get_ell_cl( + data_type=data_type, + tracer1=tracer_names[0], + tracer2=tracer_names[1], + return_cov=False, + return_ind=True, + ) + + window = extract_window_function(sacc_data, indices) + if window is not None: + two_point_cwindows.append( + TwoPointCWindow( + XY=XY, + window=window, + Cell=TwoPointMeasurement( + data=Cells, + indices=indices, + covariance_name=cov_hash, + ), + ) ) - else: - part_1 = f"{a.sacc_type_name()}{b.sacc_type_name().capitalize()}_" - if a.sacc_measurement_name() == b.sacc_measurement_name(): - part_2 = f"{a.sacc_measurement_name()}_" else: - part_2 = ( - f"{a.sacc_measurement_name()}{b.sacc_measurement_name().capitalize()}_" + two_point_cells.append( + TwoPointCells( + XY=XY, + ells=ells, + Cell=TwoPointMeasurement( + data=Cells, + indices=indices, + covariance_name=cov_hash, + ), + ) ) - return part_1 + part_2 + return two_point_cells, two_point_cwindows + + +def extract_all_data_xi_thetas( + sacc_data: sacc.Sacc, + allowed_data_type: None | list[str] = None, + include_maybe_types=False, +) -> list[TwoPointXiTheta]: + """Extract the two-point function metadata and data 
from a sacc file.""" + inferred_galaxy_zdists_dict = { + igz.bin_name: igz + for igz in extract_all_tracers( + sacc_data, include_maybe_types=include_maybe_types + ) + } + + cov_hash = hashlib.sha256(sacc_data.covariance.dense).hexdigest() + + two_point_xi_thetas = [] + for xi_theta_index in extract_all_data_types_xi_thetas( + sacc_data, allowed_data_type + ): + tracer_names = xi_theta_index["tracer_names"] + thetas = xi_theta_index["thetas"] + data_type = xi_theta_index["data_type"] + + XY = _build_two_point_xy(inferred_galaxy_zdists_dict, tracer_names, data_type) + + thetas, Xis, indices = sacc_data.get_theta_xi( + data_type=data_type, + tracer1=tracer_names[0], + tracer2=tracer_names[1], + return_cov=False, + return_ind=True, + ) + + Xi = TwoPointMeasurement( + data=Xis, + indices=indices, + covariance_name=cov_hash, + ) + + two_point_xi_thetas.append(TwoPointXiTheta(XY=XY, thetas=thetas, xis=Xi)) + + return two_point_xi_thetas -def type_to_sacc_string_real(x: Measurement, y: Measurement) -> str: - """Return the final SACC string used to denote the real-space correlation. +def check_two_point_consistence_harmonic( + two_point_cells: Sequence[TwoPointCells | TwoPointCWindow], +) -> None: + """Check the indices of the harmonic-space two-point functions. - The SACC string used to denote the real-space correlation type - between measurements of x and y. + Make sure the indices of the harmonic-space two-point functions are consistent. """ - a, b = sorted([x, y]) - suffix = f"{a.polarization()}{b.polarization()}" + all_indices_set: set[int] = set() + index_set_list = [] + cov_name: None | str = None - if a in HARMONIC_ONLY_MEASUREMENTS or b in HARMONIC_ONLY_MEASUREMENTS: - raise ValueError("Real-space correlation not supported for shear E.") + for two_point_cell in two_point_cells: + if two_point_cell.Cell is None: + raise ValueError( + f"The TwoPointCells {two_point_cell} does not contain a data." + ) + if cov_name is None: + cov_name = two_point_cell.Cell.covariance_name + elif cov_name != two_point_cell.Cell.covariance_name: + raise ValueError( + f"The TwoPointCells {two_point_cell} has a different covariance name " + f"{two_point_cell.Cell.covariance_name} than the previous " + f"TwoPointCells {cov_name}." + ) + index_set = set(two_point_cell.Cell.indices) + index_set_list.append(index_set) + if len(index_set) != len(two_point_cell.Cell.indices): + raise ValueError( + f"The indices of the TwoPointCells {two_point_cell} are not unique." + ) - return _type_to_sacc_string_common(x, y) + (f"xi_{suffix}" if suffix else "xi") + if all_indices_set & index_set: + for i, index_set_a in enumerate(index_set_list): + if index_set_a & index_set: + raise ValueError( + f"The indices of the TwoPointCells {two_point_cells[i]} and " + f"{two_point_cell} overlap." + ) + all_indices_set.update(index_set) -def type_to_sacc_string_harmonic(x: Measurement, y: Measurement) -> str: - """Return the final SACC string used to denote the harmonic-space correlation. +def check_two_point_consistence_real( + two_point_xi_thetas: Sequence[TwoPointXiTheta], +) -> None: + """Check the indices of the real-space two-point functions. - the SACC string used to denote the harmonic-space correlation type - between measurements of x and y. + Make sure the indices of the real-space two-point functions are consistent. 
""" - a, b = sorted([x, y]) - suffix = f"{a.polarization()}{b.polarization()}" + all_indices_set: set[int] = set() + index_set_list = [] + cov_name: None | str = None + + for two_point_xi_theta in two_point_xi_thetas: + if two_point_xi_theta.xis is None: + raise ValueError( + f"The TwoPointXiTheta {two_point_xi_theta} does not contain a data." + ) + if cov_name is None: + cov_name = two_point_xi_theta.xis.covariance_name + elif cov_name != two_point_xi_theta.xis.covariance_name: + raise ValueError( + f"The TwoPointXiTheta {two_point_xi_theta} has a different covariance " + f"name {two_point_xi_theta.xis.covariance_name} than the previous " + f"TwoPointXiTheta {cov_name}." + ) + index_set = set(two_point_xi_theta.xis.indices) + index_set_list.append(index_set) + if len(index_set) != len(two_point_xi_theta.xis.indices): + raise ValueError( + f"The indices of the TwoPointXiTheta {two_point_xi_theta} " + f"are not unique." + ) - if a in REAL_ONLY_MEASUREMENTS or b in REAL_ONLY_MEASUREMENTS: - raise ValueError("Harmonic-space correlation not supported for shear T.") + if all_indices_set & index_set: + for i, index_set_a in enumerate(index_set_list): + if index_set_a & index_set: + raise ValueError( + f"The indices of the TwoPointXiTheta {two_point_xi_thetas[i]} " + f"and {two_point_xi_theta} overlap." + ) + all_indices_set.update(index_set) - return _type_to_sacc_string_common(x, y) + (f"cl_{suffix}" if suffix else "cl") +def make_all_photoz_bin_combinations( + inferred_galaxy_zdists: list[InferredGalaxyZDist], +) -> list[TwoPointXY]: + """Extract the two-point function metadata from a sacc file.""" + bin_combinations = [ + TwoPointXY( + x=igz1, y=igz2, x_measurement=x_measurement, y_measurement=y_measurement + ) + for igz1, igz2 in combinations_with_replacement(inferred_galaxy_zdists, 2) + for x_measurement, y_measurement in product( + igz1.measurements, igz2.measurements + ) + if measurement_is_compatible(x_measurement, y_measurement) + ] -MEASURED_TYPE_STRING_MAP: dict[str, tuple[Measurement, Measurement]] = { - type_to_sacc_string_real(a, b): (a, b) if a < b else (b, a) - for a, b in combinations_with_replacement(ALL_MEASUREMENTS, 2) - if measurement_supports_real(a) and measurement_supports_real(b) -} | { - type_to_sacc_string_harmonic(a, b): (a, b) if a < b else (b, a) - for a, b in combinations_with_replacement(ALL_MEASUREMENTS, 2) - if measurement_supports_harmonic(a) and measurement_supports_harmonic(b) -} + return bin_combinations + + +def measurements_from_index( + index: TwoPointXiThetaIndex | TwoPointCellsIndex, +) -> tuple[str, Measurement, str, Measurement]: + """Return the measurements from a TwoPointXiThetaIndex object.""" + a, b = MEASURED_TYPE_STRING_MAP[index["data_type"]] + _, n1, a, n2, b = match_name_type( + index["tracer_names"].name1, + index["tracer_names"].name2, + a, + b, + require_convetion=True, + ) + return n1, a, n2, b diff --git a/firecrown/metadata/two_point_types.py b/firecrown/metadata/two_point_types.py new file mode 100644 index 00000000..6b8f644c --- /dev/null +++ b/firecrown/metadata/two_point_types.py @@ -0,0 +1,752 @@ +"""This module deals with two-point types. + +This module contains two-point types definitions. 
+""" + +from itertools import chain, combinations_with_replacement +from dataclasses import dataclass +import re +from enum import Enum, auto + +import numpy as np +import numpy.typing as npt + +from firecrown.utils import compare_optional_arrays, compare_optionals, YAMLSerializable + + +@dataclass(frozen=True) +class TracerNames(YAMLSerializable): + """The names of the two tracers in the sacc file.""" + + name1: str + name2: str + + def __getitem__(self, item): + """Get the name of the tracer at the given index.""" + if item == 0: + return self.name1 + if item == 1: + return self.name2 + raise IndexError + + def __iter__(self): + """Iterate through the data members. + + This is to allow automatic unpacking. + """ + yield self.name1 + yield self.name2 + + +TRACER_NAMES_TOTAL = TracerNames("", "") # special name to represent total + + +class Galaxies(YAMLSerializable, str, Enum): + """This enumeration type for galaxy measurements. + + It provides identifiers for the different types of galaxy-related types of + measurement. + + SACC has some notion of supporting other types, but incomplete implementation. When + support for more types is added to SACC this enumeration needs to be updated. + """ + + SHEAR_E = auto() + SHEAR_T = auto() + SHEAR_MINUS = auto() + SHEAR_PLUS = auto() + COUNTS = auto() + + def sacc_type_name(self) -> str: + """Return the lower-case form of the main measurement type. + + This is the first part of the SACC string used to denote a correlation between + measurements of this type. + """ + return "galaxy" + + def sacc_measurement_name(self) -> str: + """Return the lower-case form of the specific measurement type. + + This is the second part of the SACC string used to denote the specific + measurement type. + """ + if self == Galaxies.SHEAR_E: + return "shear" + if self == Galaxies.SHEAR_T: + return "shear" + if self == Galaxies.SHEAR_MINUS: + return "shear" + if self == Galaxies.SHEAR_PLUS: + return "shear" + if self == Galaxies.COUNTS: + return "density" + raise ValueError("Untranslated Galaxy Measurement encountered") + + def polarization(self) -> str: + """Return the SACC polarization code. + + This is the third part of the SACC string used to denote the specific + measurement type. + """ + if self == Galaxies.SHEAR_E: + return "e" + if self == Galaxies.SHEAR_T: + return "t" + if self == Galaxies.SHEAR_MINUS: + return "minus" + if self == Galaxies.SHEAR_PLUS: + return "plus" + if self == Galaxies.COUNTS: + return "" + raise ValueError("Untranslated Galaxy Measurement encountered") + + def __lt__(self, other): + """Define a comparison function for the Galaxy Measurement enumeration.""" + return compare_enums(self, other) < 0 + + def __eq__(self, other): + """Define an equality test for Galaxy Measurement enumeration.""" + return compare_enums(self, other) == 0 + + def __ne__(self, other): + """Negation of __eq__.""" + return not self.__eq__(other) + + def __hash__(self) -> int: + """Define a hash function that uses both type and value information.""" + return hash((Galaxies, self.value)) + + +class CMB(YAMLSerializable, str, Enum): + """This enumeration type for CMB measurements. + + It provides identifiers for the different types of CMB-related types of + measurement. + + SACC has some notion of supporting other types, but incomplete implementation. When + support for more types is added to SACC this enumeration needs to be updated. + """ + + CONVERGENCE = auto() + + def sacc_type_name(self) -> str: + """Return the lower-case form of the main measurement type. 
+ + This is the first part of the SACC string used to denote a correlation between + measurements of this type. + """ + return "cmb" + + def sacc_measurement_name(self) -> str: + """Return the lower-case form of the specific measurement type. + + This is the second part of the SACC string used to denote the specific + measurement type. + """ + if self == CMB.CONVERGENCE: + return "convergence" + raise ValueError("Untranslated CMBMeasurement encountered") + + def polarization(self) -> str: + """Return the SACC polarization code. + + This is the third part of the SACC string used to denote the specific + measurement type. + """ + if self == CMB.CONVERGENCE: + return "" + raise ValueError("Untranslated CMBMeasurement encountered") + + def __lt__(self, other): + """Define a comparison function for the CMBMeasurement enumeration.""" + return compare_enums(self, other) < 0 + + def __eq__(self, other): + """Define an equality test for CMBMeasurement enumeration.""" + return compare_enums(self, other) == 0 + + def __ne__(self, other): + """Negation of __eq__.""" + return not self.__eq__(other) + + def __hash__(self) -> int: + """Define a hash function that uses both type and value information.""" + return hash((CMB, self.value)) + + +class Clusters(YAMLSerializable, str, Enum): + """This enumeration type for cluster measurements. + + It provides identifiers for the different types of cluster-related types of + measurement. + + SACC has some notion of supporting other types, but incomplete implementation. When + support for more types is added to SACC this enumeration needs to be updated. + """ + + COUNTS = auto() + + def sacc_type_name(self) -> str: + """Return the lower-case form of the main measurement type. + + This is the first part of the SACC string used to denote a correlation between + measurements of this type. + """ + return "cluster" + + def sacc_measurement_name(self) -> str: + """Return the lower-case form of the specific measurement type. + + This is the second part of the SACC string used to denote the specific + measurement type. + """ + if self == Clusters.COUNTS: + return "density" + raise ValueError("Untranslated ClusterMeasurement encountered") + + def polarization(self) -> str: + """Return the SACC polarization code. + + This is the third part of the SACC string used to denote the specific + measurement type. 
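+
+        For cluster counts this is the empty string, so no polarization suffix
+        is appended to the SACC data-type name.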
+ """ + if self == Clusters.COUNTS: + return "" + raise ValueError("Untranslated ClusterMeasurement encountered") + + def __lt__(self, other): + """Define a comparison function for the ClusterMeasurement enumeration.""" + return compare_enums(self, other) < 0 + + def __eq__(self, other): + """Define an equality test for ClusterMeasurement enumeration.""" + return compare_enums(self, other) == 0 + + def __ne__(self, other): + """Negation of __eq__.""" + return not self.__eq__(other) + + def __hash__(self) -> int: + """Define a hash function that uses both type and value information.""" + return hash((Clusters, self.value)) + + +Measurement = Galaxies | CMB | Clusters +ALL_MEASUREMENTS: list[Measurement] = list(chain(Galaxies, CMB, Clusters)) +ALL_MEASUREMENT_TYPES = (Galaxies, CMB, Clusters) +HARMONIC_ONLY_MEASUREMENTS = (Galaxies.SHEAR_E,) +REAL_ONLY_MEASUREMENTS = (Galaxies.SHEAR_T, Galaxies.SHEAR_MINUS, Galaxies.SHEAR_PLUS) +EXACT_MATCH_MEASUREMENTS = (Galaxies.SHEAR_MINUS, Galaxies.SHEAR_PLUS) +LENS_REGEX = re.compile(r"^lens\d+$") +SOURCE_REGEX = re.compile(r"^(src\d+|source\d+)$") +GALAXY_SOURCE_TYPES = ( + Galaxies.SHEAR_E, + Galaxies.SHEAR_T, + Galaxies.SHEAR_MINUS, + Galaxies.SHEAR_PLUS, +) +GALAXY_LENS_TYPES = (Galaxies.COUNTS,) + + +def compare_enums(a: Measurement, b: Measurement) -> int: + """Define a comparison function for the Measurement enumeration. + + Return -1 if a comes before b, 0 if they are the same, and +1 if b comes before a. + """ + order = (CMB, Clusters, Galaxies) + if type(a) not in order or type(b) not in order: + raise ValueError( + f"Unknown measurement type encountered ({type(a)}, {type(b)})." + ) + + main_type_index_a = order.index(type(a)) + main_type_index_b = order.index(type(b)) + if main_type_index_a == main_type_index_b: + return int(a) - int(b) + return main_type_index_a - main_type_index_b + + +@dataclass(frozen=True, kw_only=True) +class InferredGalaxyZDist(YAMLSerializable): + """The class used to store the redshift resolution data for a sacc file. + + The sacc file is a complicated set of tracers (bins) and surveys. This class is + used to store the redshift resolution data for a single photometric bin. + """ + + bin_name: str + z: np.ndarray + dndz: np.ndarray + measurements: set[Measurement] + + def __post_init__(self) -> None: + """Validate the redshift resolution data. + + - Make sure the z and dndz arrays have the same shape; + - The measurement must be of type Measurement. + - The bin_name should not be empty. + """ + if self.z.shape != self.dndz.shape: + raise ValueError("The z and dndz arrays should have the same shape.") + + for measurement in self.measurements: + if not isinstance(measurement, ALL_MEASUREMENT_TYPES): + raise ValueError("The measurement should be a Measurement.") + + if self.bin_name == "": + raise ValueError("The bin_name should not be empty.") + + def __eq__(self, other): + """Equality test for InferredGalaxyZDist. + + Two InferredGalaxyZDist are equal if they have equal bin_name, z, dndz, and + measurement. + """ + assert isinstance(other, InferredGalaxyZDist) + return ( + self.bin_name == other.bin_name + and np.array_equal(self.z, other.z) + and np.array_equal(self.dndz, other.dndz) + and self.measurements == other.measurements + ) + + +@dataclass(frozen=True, kw_only=True) +class TwoPointMeasurement(YAMLSerializable): + """Class defining the metadata for a two-point measurement. + + The class used to store the metadata for a two-point function measured on a sphere. 
+ + This includes the measured two-point function and their indices in the covariance + matrix. + """ + + data: npt.NDArray[np.float64] + indices: npt.NDArray[np.int64] + covariance_name: str + + def __post_init__(self) -> None: + """Make sure the data and indices have the same shape.""" + if len(self.data.shape) != 1: + raise ValueError("Data should be a 1D array.") + + if self.data.shape != self.indices.shape: + raise ValueError("Data and indices should have the same shape.") + + def __eq__(self, other) -> bool: + """Equality test for TwoPointMeasurement objects.""" + return ( + np.array_equal(self.data, other.data) + and np.array_equal(self.indices, other.indices) + and self.covariance_name == other.covariance_name + ) + + +def make_measurements_dict(value: set[Measurement]) -> list[dict[str, str]]: + """Create a dictionary from a Measurement object. + + :param value: the measurement to turn into a dictionary + """ + return [ + {"subject": type(measurement).__name__, "property": measurement.name} + for measurement in value + ] + + +def measurement_is_compatible(a: Measurement, b: Measurement) -> bool: + """Check if two Measurement are compatible. + + Two Measurement are compatible if they can be correlated in a two-point function. + """ + if a in HARMONIC_ONLY_MEASUREMENTS and b in REAL_ONLY_MEASUREMENTS: + return False + if a in REAL_ONLY_MEASUREMENTS and b in HARMONIC_ONLY_MEASUREMENTS: + return False + if (a in EXACT_MATCH_MEASUREMENTS or b in EXACT_MATCH_MEASUREMENTS) and a != b: + return False + return True + + +def measurement_supports_real(x: Measurement) -> bool: + """Return True if x supports real-space calculations.""" + return x not in HARMONIC_ONLY_MEASUREMENTS + + +def measurement_supports_harmonic(x: Measurement) -> bool: + """Return True if x supports harmonic-space calculations.""" + return x not in REAL_ONLY_MEASUREMENTS + + +def measurement_is_compatible_real(a: Measurement, b: Measurement) -> bool: + """Check if two Measurement are compatible for real-space calculations. + + Two Measurement are compatible if they can be correlated in a real-space two-point + function. + """ + return ( + measurement_supports_real(a) + and measurement_supports_real(b) + and measurement_is_compatible(a, b) + ) + + +def measurement_is_compatible_harmonic(a: Measurement, b: Measurement) -> bool: + """Check if two Measurement are compatible for harmonic-space calculations. + + Two Measurement are compatible if they can be correlated in a harmonic-space + two-point function. + """ + return ( + measurement_supports_harmonic(a) + and measurement_supports_harmonic(b) + and measurement_is_compatible(a, b) + ) + + +def _type_to_sacc_string_common(x: Measurement, y: Measurement) -> str: + """Return the first two parts of the SACC string. + + The first two parts of the SACC string is used to denote a correlation between + measurements of x and y. 
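+
+    For example, correlating Galaxies.SHEAR_E with Galaxies.COUNTS yields the
+    prefix "galaxy_shearDensity_", to which the real- or harmonic-space helpers
+    below append the final piece (e.g. "cl_e").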
+ """ + a, b = sorted([x, y]) + if isinstance(a, type(b)): + part_1 = f"{a.sacc_type_name()}_" + if a == b: + part_2 = f"{a.sacc_measurement_name()}_" + else: + part_2 = ( + f"{a.sacc_measurement_name()}{b.sacc_measurement_name().capitalize()}_" + ) + else: + part_1 = f"{a.sacc_type_name()}{b.sacc_type_name().capitalize()}_" + if a.sacc_measurement_name() == b.sacc_measurement_name(): + part_2 = f"{a.sacc_measurement_name()}_" + else: + part_2 = ( + f"{a.sacc_measurement_name()}{b.sacc_measurement_name().capitalize()}_" + ) + + return part_1 + part_2 + + +def type_to_sacc_string_real(x: Measurement, y: Measurement) -> str: + """Return the final SACC string used to denote the real-space correlation. + + The SACC string used to denote the real-space correlation type + between measurements of x and y. + """ + a, b = sorted([x, y]) + if a in EXACT_MATCH_MEASUREMENTS: + assert a == b + suffix = f"{a.polarization()}" + else: + suffix = f"{a.polarization()}{b.polarization()}" + + if a in HARMONIC_ONLY_MEASUREMENTS or b in HARMONIC_ONLY_MEASUREMENTS: + raise ValueError("Real-space correlation not supported for shear E.") + + return _type_to_sacc_string_common(x, y) + (f"xi_{suffix}" if suffix else "xi") + + +def type_to_sacc_string_harmonic(x: Measurement, y: Measurement) -> str: + """Return the final SACC string used to denote the harmonic-space correlation. + + the SACC string used to denote the harmonic-space correlation type + between measurements of x and y. + """ + a, b = sorted([x, y]) + suffix = f"{a.polarization()}{b.polarization()}" + + if a in REAL_ONLY_MEASUREMENTS or b in REAL_ONLY_MEASUREMENTS: + raise ValueError("Harmonic-space correlation not supported for shear T.") + + return _type_to_sacc_string_common(x, y) + (f"cl_{suffix}" if suffix else "cl") + + +MEASURED_TYPE_STRING_MAP: dict[str, tuple[Measurement, Measurement]] = { + type_to_sacc_string_real(a, b): (a, b) if a < b else (b, a) + for a, b in combinations_with_replacement(ALL_MEASUREMENTS, 2) + if measurement_is_compatible_real(a, b) +} | { + type_to_sacc_string_harmonic(a, b): (a, b) if a < b else (b, a) + for a, b in combinations_with_replacement(ALL_MEASUREMENTS, 2) + if measurement_is_compatible_harmonic(a, b) +} + + +@dataclass(frozen=True, kw_only=True) +class TwoPointXY(YAMLSerializable): + """Class defining a two-point correlation pair of redshift resolutions. + + It is used to store the two redshift resolutions for the two bins being + correlated. + """ + + x: InferredGalaxyZDist + y: InferredGalaxyZDist + x_measurement: Measurement + y_measurement: Measurement + + def __post_init__(self) -> None: + """Make sure the two redshift resolutions are compatible.""" + if self.x_measurement not in self.x.measurements: + raise ValueError( + f"Measurement {self.x_measurement} not in the measurements of " + f"{self.x.bin_name}." + ) + if self.y_measurement not in self.y.measurements: + raise ValueError( + f"Measurement {self.y_measurement} not in the measurements of " + f"{self.y.bin_name}." + ) + if not measurement_is_compatible(self.x_measurement, self.y_measurement): + raise ValueError( + f"Measurements {self.x_measurement} and {self.y_measurement} " + f"are not compatible." 
+ ) + + def __eq__(self, other) -> bool: + """Equality test for TwoPointXY objects.""" + return ( + self.x == other.x + and self.y == other.y + and self.x_measurement == other.x_measurement + and self.y_measurement == other.y_measurement + ) + + def __str__(self) -> str: + """Return a string representation of the TwoPointXY object.""" + return f"({self.x.bin_name}, {self.y.bin_name})" + + def get_tracer_names(self) -> TracerNames: + """Return the TracerNames object for the TwoPointXY object.""" + return TracerNames(self.x.bin_name, self.y.bin_name) + + +@dataclass(frozen=True, kw_only=True) +class TwoPointCells(YAMLSerializable): + """Class defining the metadata for an harmonic-space two-point measurement. + + The class used to store the metadata for a (spherical) harmonic-space two-point + function measured on a sphere. + + This includes the two redshift resolutions (one for each binned quantity) and the + array of (integer) l's at which the two-point function which has this metadata were + calculated. + """ + + XY: TwoPointXY + ells: npt.NDArray[np.int64] + Cell: None | TwoPointMeasurement = None + + def __post_init__(self) -> None: + """Validate the TwoPointCells data. + + Make sure the ells are a 1D array and X and Y are compatible + with harmonic-space calculations. + """ + if len(self.ells.shape) != 1: + raise ValueError("Ells should be a 1D array.") + + if self.Cell is not None and self.Cell.data.shape != self.ells.shape: + raise ValueError("Cell should have the same shape as ells.") + + if not measurement_supports_harmonic( + self.XY.x_measurement + ) or not measurement_supports_harmonic(self.XY.y_measurement): + raise ValueError( + f"Measurements {self.XY.x_measurement} and " + f"{self.XY.y_measurement} must support harmonic-space calculations." + ) + + def __eq__(self, other) -> bool: + """Equality test for TwoPointCells objects.""" + return ( + self.XY == other.XY + and np.array_equal(self.ells, other.ells) + and compare_optionals(self.Cell, other.Cell) + ) + + def __str__(self) -> str: + """Return a string representation of the TwoPointCells object.""" + return f"{self.XY}[{self.get_sacc_name()}]" + + def get_sacc_name(self) -> str: + """Return the SACC name for the two-point function.""" + return type_to_sacc_string_harmonic( + self.XY.x_measurement, self.XY.y_measurement + ) + + def has_data(self) -> bool: + """Return True if the TwoPointCells object has a Cell array.""" + return self.Cell is not None + + +@dataclass(kw_only=True) +class Window(YAMLSerializable): + """The class used to represent a window function. + + It contains the ells at which the window function is defined, the weights + of the window function, and the ells at which the window function is + interpolated. + + It may contain the ells for interpolation if the theory prediction is + calculated at a different set of ells than the window function. 
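+
+    The weights matrix has one row per ell and one column per observation: for
+    N ells and M observed band powers the weights have shape (N, M), and
+    n_observations() returns M.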
+ """ + + ells: npt.NDArray[np.int64] + weights: npt.NDArray[np.float64] + ells_for_interpolation: None | npt.NDArray[np.int64] = None + + def __post_init__(self) -> None: + """Make sure the weights have the right shape.""" + if len(self.ells.shape) != 1: + raise ValueError("Ells should be a 1D array.") + if len(self.weights.shape) != 2: + raise ValueError("Weights should be a 2D array.") + if self.weights.shape[0] != len(self.ells): + raise ValueError("Weights should have the same number of rows as ells.") + if ( + self.ells_for_interpolation is not None + and len(self.ells_for_interpolation.shape) != 1 + ): + raise ValueError("Ells for interpolation should be a 1D array.") + + def n_observations(self) -> int: + """Return the number of observations supported by the window function.""" + return self.weights.shape[1] + + def __eq__(self, other) -> bool: + """Equality test for Window objects.""" + assert isinstance(other, Window) + # We will need special handling for the optional ells_for_interpolation. + # First handle the non-optinal parts. + partial_result = np.array_equal(self.ells, other.ells) and np.array_equal( + self.weights, other.weights + ) + if not partial_result: + return False + return compare_optional_arrays( + self.ells_for_interpolation, other.ells_for_interpolation + ) + + +@dataclass(frozen=True, kw_only=True) +class TwoPointCWindow(YAMLSerializable): + """Two-point function with a window function. + + The class used to store the metadata for a (spherical) harmonic-space two-point + function measured on a sphere, with an associated window function. + + This includes the two redshift resolutions (one for each binned quantity) and the + matrix (window function) that relates the measured Cl's with the predicted Cl's. + + Note that the matrix `window` always has l=0 and l=1 suppressed. + """ + + XY: TwoPointXY + window: Window + Cell: None | TwoPointMeasurement = None + + def __post_init__(self): + """Validate the TwoPointCWindow data. + + Make sure the window is + """ + if not isinstance(self.window, Window): + raise ValueError("Window should be a Window object.") + + if self.Cell is not None: + if len(self.Cell.data) != self.window.n_observations(): + raise ValueError( + "Data should have the same number of elements as the number of " + "observations supported by the window function." + ) + + if not measurement_supports_harmonic( + self.XY.x_measurement + ) or not measurement_supports_harmonic(self.XY.y_measurement): + raise ValueError( + f"Measurements {self.XY.x_measurement} and " + f"{self.XY.y_measurement} must support harmonic-space calculations." + ) + + def __str__(self) -> str: + """Return a string representation of the TwoPointCWindow object.""" + return f"{self.XY}[{self.get_sacc_name()}]" + + def get_sacc_name(self) -> str: + """Return the SACC name for the two-point function.""" + return type_to_sacc_string_harmonic( + self.XY.x_measurement, self.XY.y_measurement + ) + + def __eq__(self, other) -> bool: + """Equality test for TwoPointCWindow objects.""" + return ( + self.XY == other.XY + and self.window == other.window + and compare_optionals(self.Cell, other.Cell) + ) + + def has_data(self) -> bool: + """Return True if the TwoPointCWindow object has a Cell array.""" + return self.Cell is not None + + +@dataclass(frozen=True, kw_only=True) +class TwoPointXiTheta(YAMLSerializable): + """Class defining the metadata for a real-space two-point measurement. + + The class used to store the metadata for a real-space two-point function measured + on a sphere. 
+
+    This includes the two redshift resolutions (one for each binned quantity) and the
+    array of (floating point) theta (angle) values at which the two-point function
+    described by this metadata was calculated.
+    """
+
+    XY: TwoPointXY
+    thetas: npt.NDArray[np.float64]
+    xis: None | TwoPointMeasurement = None
+
+    def __post_init__(self):
+        """Validate the TwoPointXiTheta data.
+
+        Make sure the thetas are a 1D array, the xis (if present) have the
+        same shape as the thetas, and X and Y are compatible with real-space
+        calculations.
+        """
+        if len(self.thetas.shape) != 1:
+            raise ValueError("Thetas should be a 1D array.")
+
+        if self.xis is not None and self.xis.data.shape != self.thetas.shape:
+            raise ValueError("Xis should have the same shape as thetas.")
+
+        if not measurement_supports_real(
+            self.XY.x_measurement
+        ) or not measurement_supports_real(self.XY.y_measurement):
+            raise ValueError(
+                f"Measurements {self.XY.x_measurement} and "
+                f"{self.XY.y_measurement} must support real-space calculations."
+            )
+
+    def __str__(self) -> str:
+        """Return a string representation of the TwoPointXiTheta object."""
+        return f"{self.XY}[{self.get_sacc_name()}]"
+
+    def get_sacc_name(self) -> str:
+        """Return the SACC name for the two-point function."""
+        return type_to_sacc_string_real(self.XY.x_measurement, self.XY.y_measurement)
+
+    def __eq__(self, other) -> bool:
+        """Equality test for TwoPointXiTheta objects."""
+        return (
+            self.XY == other.XY
+            and np.array_equal(self.thetas, other.thetas)
+            and compare_optionals(self.xis, other.xis)
+        )
+
+    def has_data(self) -> bool:
+        """Return True if the TwoPointXiTheta object has a xis array."""
+        return self.xis is not None
diff --git a/firecrown/utils.py b/firecrown/utils.py
index 0aceabad..e77ef628 100644
--- a/firecrown/utils.py
+++ b/firecrown/utils.py
@@ -1,7 +1,7 @@
 """Some utility functions for patterns common in Firecrown."""
 
 from __future__ import annotations
-from typing import Generator
+from typing import Generator, TypeVar, Type
 
 import numpy as np
 import numpy.typing as npt
@@ -10,6 +10,23 @@
 import sacc
 import yaml
 
+from yaml import CLoader as Loader
+from yaml import CDumper as Dumper
+
+ST = TypeVar("ST")  # This will be used in YAMLSerializable
+
+
+class YAMLSerializable:
+    """Mixin for classes that can be serialized to and from YAML."""
+
+    def to_yaml(self: ST) -> str:
+        """Return the YAML representation of the object."""
+        return yaml.dump(self, Dumper=Dumper, sort_keys=False)
+
+    @classmethod
+    def from_yaml(cls: Type[ST], yaml_str: str) -> ST:
+        """Load the object from YAML."""
+        return yaml.load(yaml_str, Loader=Loader)
 
 
 def base_model_from_yaml(cls: type, yaml_str: str):
@@ -101,3 +118,18 @@ def compare_optional_arrays(x: None | npt.NDArray, y: None | npt.NDArray) -> boo
         return np.array_equal(x, y)
     # One is None and the other is not.
     return False
+
+
+def compare_optionals(x: None | object, y: None | object) -> bool:
+    """Compare two objects, allowing for either or both to be None.
+
+    :param x: first object
+    :param y: second object
+    :return: whether the objects are equal
+    """
+    if x is None and y is None:
+        return True
+    if x is not None and y is not None:
+        return x == y
+    # One is None and the other is not.
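+    # For illustration: compare_optionals(None, None) is True,
+    # compare_optionals(1.0, 1.0) is True, and compare_optionals(1.0, None) is False.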
+    return False
diff --git a/tests/conftest.py b/tests/conftest.py
index 6c637bb6..ee9813e7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -18,7 +18,22 @@
 from firecrown.parameters import ParamsMap
 from firecrown.connector.mapping import MappingCosmoSIS, mapping_builder
 from firecrown.modeling_tools import ModelingTools
-from firecrown.metadata.two_point import TracerNames
+from firecrown.metadata.two_point_types import (
+    Galaxies,
+    measurement_is_compatible_harmonic,
+    measurement_is_compatible_real,
+)
+from firecrown.metadata.two_point import (
+    TracerNames,
+    InferredGalaxyZDist,
+    Window,
+    TwoPointXY,
+    TwoPointCells,
+    TwoPointCWindow,
+    TwoPointXiTheta,
+)
+import firecrown.likelihood.weak_lensing as wl
+import firecrown.likelihood.number_counts as nc
 
 
 def pytest_addoption(parser):
@@ -148,9 +163,152 @@ def fixture_tools_with_vanilla_cosmology():
     result.update(ParamsMap())
     result.prepare(pyccl.CosmologyVanillaLCDM())
 
+    return result
+
+
+@pytest.fixture(
+    name="harmonic_bin_1",
+    params=[Galaxies.COUNTS, Galaxies.SHEAR_E],
+)
+def make_harmonic_bin_1(request) -> InferredGalaxyZDist:
+    """Generate an InferredGalaxyZDist object sampled at 5 redshifts."""
+    x = InferredGalaxyZDist(
+        bin_name="bin_1",
+        z=np.linspace(0, 1, 5),
+        dndz=np.array([0.1, 0.5, 0.2, 0.3, 0.4]),
+        measurements={request.param},
+    )
+    return x
+
+
+@pytest.fixture(
+    name="harmonic_bin_2",
+    params=[Galaxies.COUNTS, Galaxies.SHEAR_E],
+)
+def make_harmonic_bin_2(request) -> InferredGalaxyZDist:
+    """Generate an InferredGalaxyZDist object sampled at 3 redshifts."""
+    x = InferredGalaxyZDist(
+        bin_name="bin_2",
+        z=np.linspace(0, 1, 3),
+        dndz=np.array([0.1, 0.5, 0.4]),
+        measurements={request.param},
+    )
+    return x
+
+
+@pytest.fixture(
+    name="real_bin_1",
+    params=[
+        Galaxies.COUNTS,
+        Galaxies.SHEAR_T,
+        Galaxies.SHEAR_MINUS,
+        Galaxies.SHEAR_PLUS,
+    ],
+)
+def make_real_bin_1(request) -> InferredGalaxyZDist:
+    """Generate an InferredGalaxyZDist object sampled at 5 redshifts."""
+    x = InferredGalaxyZDist(
+        bin_name="bin_1",
+        z=np.linspace(0, 1, 5),
+        dndz=np.array([0.1, 0.5, 0.2, 0.3, 0.4]),
+        measurements={request.param},
+    )
+    return x
+
+
+@pytest.fixture(
+    name="real_bin_2",
+    params=[
+        Galaxies.COUNTS,
+        Galaxies.SHEAR_T,
+        Galaxies.SHEAR_MINUS,
+        Galaxies.SHEAR_PLUS,
+    ],
+)
+def make_real_bin_2(request) -> InferredGalaxyZDist:
+    """Generate an InferredGalaxyZDist object sampled at 3 redshifts."""
+    x = InferredGalaxyZDist(
+        bin_name="bin_2",
+        z=np.linspace(0, 1, 3),
+        dndz=np.array([0.1, 0.5, 0.4]),
+        measurements={request.param},
+    )
+    return x
+
+
+@pytest.fixture(name="window_1")
+def make_window_1() -> Window:
+    """Generate a Window object with 100 ells."""
+    ells = np.array(np.linspace(0, 100, 100), dtype=np.int64)
+    ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64)
+    weights = np.ones(400).reshape(-1, 4)
+
+    window = Window(
+        ells=ells,
+        weights=weights,
+        ells_for_interpolation=ells_for_interpolation,
+    )
+    return window
+
+
+@pytest.fixture(name="harmonic_two_point_xy")
+def make_harmonic_two_point_xy(
+    harmonic_bin_1: InferredGalaxyZDist,
+    harmonic_bin_2: InferredGalaxyZDist,
+) -> TwoPointXY:
+    """Generate a TwoPointXY object for harmonic-space measurements."""
+    m1 = list(harmonic_bin_1.measurements)[0]
+    m2 = list(harmonic_bin_2.measurements)[0]
+    if not measurement_is_compatible_harmonic(m1, m2):
+        pytest.skip("Incompatible measurements")
+    xy = TwoPointXY(
+        x=harmonic_bin_1, y=harmonic_bin_2, x_measurement=m1, y_measurement=m2
+    )
+    return xy
+
+
+@pytest.fixture(name="real_two_point_xy")
+def make_real_two_point_xy(
+    real_bin_1: InferredGalaxyZDist,
+    real_bin_2: InferredGalaxyZDist,
+) -> TwoPointXY:
+    """Generate a TwoPointXY object for real-space measurements."""
+    m1 = list(real_bin_1.measurements)[0]
+    m2 = list(real_bin_2.measurements)[0]
+    if not measurement_is_compatible_real(m1, m2):
+        pytest.skip("Incompatible measurements")
+    xy = TwoPointXY(x=real_bin_1, y=real_bin_2, x_measurement=m1, y_measurement=m2)
+    return xy
+
+
+@pytest.fixture(name="two_point_cwindow")
+def make_two_point_cwindow(
+    window_1: Window, harmonic_two_point_xy: TwoPointXY
+) -> TwoPointCWindow:
+    """Generate a TwoPointCWindow object with 100 ells."""
+    two_point = TwoPointCWindow(XY=harmonic_two_point_xy, window=window_1)
+    return two_point
+
+
+@pytest.fixture(name="two_point_cell")
+def make_two_point_cell(
+    harmonic_two_point_xy: TwoPointXY,
+) -> TwoPointCells:
+    """Generate a TwoPointCells object with 100 ells."""
+    ells = np.array(np.linspace(0, 100, 100), dtype=np.int64)
+    return TwoPointCells(ells=ells, XY=harmonic_two_point_xy)
+
+
+@pytest.fixture(name="two_point_xi_theta")
+def make_two_point_xi_theta(real_two_point_xy: TwoPointXY) -> TwoPointXiTheta:
+    """Generate a TwoPointXiTheta object with 100 thetas."""
+    thetas = np.array(np.linspace(0, 100, 100), dtype=np.float64)
+    return TwoPointXiTheta(thetas=thetas, XY=real_two_point_xy)
 
 
 @pytest.fixture(name="cluster_sacc_data")
 def fixture_cluster_sacc_data() -> sacc.Sacc:
+    """Return a Sacc object with cluster data."""
     # pylint: disable=no-member
     cc = sacc.standard_types.cluster_counts
     # pylint: disable=no-member
@@ -355,7 +513,7 @@ def fixture_sacc_galaxy_xis_src0_lens0():
 
 
 @pytest.fixture(name="sacc_galaxy_cells")
-def fixture_sacc_galaxy_cells():
+def fixture_sacc_galaxy_cells() -> tuple[sacc.Sacc, dict, dict]:
     """Fixture for a SACC data without window functions."""
     sacc_data = sacc.Sacc()
 
@@ -367,7 +525,7 @@ def fixture_sacc_galaxy_cells():
 
     tracers: dict[str, tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]] = {}
     tracer_pairs: dict[
-        TracerNames, tuple[str, npt.NDArray[np.int64], npt.NDArray[np.float64]]
+        tuple[TracerNames, str], tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]
     ] = {}
 
     for i, mn in enumerate(src_bins_centers):
@@ -385,8 +543,7 @@ def fixture_sacc_galaxy_cells():
     for i, j in upper_triangle_indices(len(src_bins_centers)):
         Cells = np.random.normal(size=ells.shape[0])
         sacc_data.add_ell_cl("galaxy_shear_cl_ee", f"src{i}", f"src{j}", ells, Cells)
-        tracer_pairs[TracerNames(f"src{i}", f"src{j}")] = (
-            "galaxy_shear_cl_ee",
+        tracer_pairs[(TracerNames(f"src{i}", f"src{j}"), "galaxy_shear_cl_ee")] = (
             ells,
             Cells,
         )
@@ -395,8 +552,7 @@ def fixture_sacc_galaxy_cells():
     for i, j in upper_triangle_indices(len(lens_bins_centers)):
         Cells = np.random.normal(size=ells.shape[0])
         sacc_data.add_ell_cl("galaxy_density_cl", f"lens{i}", f"lens{j}", ells, Cells)
-        tracer_pairs[TracerNames(f"lens{i}", f"lens{j}")] = (
-            "galaxy_density_cl",
+        tracer_pairs[(TracerNames(f"lens{i}", f"lens{j}"), "galaxy_density_cl")] = (
             ells,
             Cells,
         )
@@ -407,8 +563,9 @@ def fixture_sacc_galaxy_cells():
         sacc_data.add_ell_cl(
             "galaxy_shearDensity_cl_e", f"src{i}", f"lens{j}", ells, Cells
         )
-        tracer_pairs[TracerNames(f"src{i}", f"lens{j}")] = (
-            "galaxy_shearDensity_cl_e",
+        tracer_pairs[
+            (TracerNames(f"src{i}", f"lens{j}"), "galaxy_shearDensity_cl_e")
+        ] = (
             ells,
             Cells,
         )
@@ -422,6 +579,108 @@ def fixture_sacc_galaxy_cells():
     return sacc_data, tracers, tracer_pairs
 
 
+@pytest.fixture(name="sacc_galaxy_cwindows")
+def fixture_sacc_galaxy_cwindows():
+ """Fixture for a SACC data with window functions.""" + sacc_data = sacc.Sacc() + + z = np.linspace(0, 1.0, 256) + 0.05 + ells = np.unique(np.logspace(1, 3, 10).astype(np.int64)) + nobs = len(ells) - 5 + + src_bins_centers = np.linspace(0.25, 0.75, 5) + lens_bins_centers = np.linspace(0.1, 0.6, 5) + + tracers: dict[str, tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]] = {} + tracer_pairs: dict[ + tuple[TracerNames, str], + tuple[npt.NDArray[np.int64], npt.NDArray[np.float64], sacc.BandpowerWindow], + ] = {} + + for i, mn in enumerate(src_bins_centers): + dndz = np.exp(-0.5 * (z - mn) ** 2 / 0.05 / 0.05) + sacc_data.add_tracer("NZ", f"src{i}", z, dndz) + tracers[f"src{i}"] = (z, dndz) + + for i, mn in enumerate(lens_bins_centers): + dndz = np.exp(-0.5 * (z - mn) ** 2 / 0.05 / 0.05) + sacc_data.add_tracer("NZ", f"lens{i}", z, dndz) + tracers[f"lens{i}"] = (z, dndz) + + for i, j in upper_triangle_indices(len(src_bins_centers)): + + weights = ( + np.eye(ells.shape[0], nobs, dtype=np.float64) + + np.eye(ells.shape[0], nobs, k=5, dtype=np.float64) + + np.eye(ells.shape[0], nobs, k=-5, dtype=np.float64) + ) + window = sacc.BandpowerWindow(ells, weights) + Cells = np.random.normal(size=nobs) + sacc_data.add_ell_cl( + "galaxy_shear_cl_ee", + f"src{i}", + f"src{j}", + weights.T.dot(ells), + Cells, + window=window, + ) + tracer_pairs[(TracerNames(f"src{i}", f"src{j}"), "galaxy_shear_cl_ee")] = ( + ells, + Cells, + window, + ) + + for i, j in upper_triangle_indices(len(lens_bins_centers)): + weights = ( + np.eye(ells.shape[0], nobs, dtype=np.float64) + + np.eye(ells.shape[0], nobs, k=5, dtype=np.float64) + + np.eye(ells.shape[0], nobs, k=-5, dtype=np.float64) + ) + window = sacc.BandpowerWindow(ells, weights) + Cells = np.random.normal(size=nobs) + sacc_data.add_ell_cl( + "galaxy_density_cl", + f"lens{i}", + f"lens{j}", + weights.T.dot(ells), + Cells, + window=window, + ) + tracer_pairs[(TracerNames(f"lens{i}", f"lens{j}"), "galaxy_density_cl")] = ( + ells, + Cells, + window, + ) + + for i, j in product(range(len(src_bins_centers)), range(len(lens_bins_centers))): + weights = ( + np.eye(ells.shape[0], nobs, dtype=np.float64) + + np.eye(ells.shape[0], nobs, k=5, dtype=np.float64) + + np.eye(ells.shape[0], nobs, k=-5, dtype=np.float64) + ) + window = sacc.BandpowerWindow(ells, weights) + Cells = np.random.normal(size=nobs) + sacc_data.add_ell_cl( + "galaxy_shearDensity_cl_e", + f"src{i}", + f"lens{j}", + weights.T.dot(ells), + Cells, + window=window, + ) + tracer_pairs[ + (TracerNames(f"src{i}", f"lens{j}"), "galaxy_shearDensity_cl_e") + ] = ( + ells, + Cells, + window, + ) + + sacc_data.add_covariance(np.identity(len(sacc_data)) * 0.01) + + return sacc_data, tracers, tracer_pairs + + @pytest.fixture(name="sacc_galaxy_xis") def fixture_sacc_galaxy_xis(): """Fixture for a SACC data without window functions.""" @@ -436,7 +695,7 @@ def fixture_sacc_galaxy_xis(): tracers: dict[str, tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]] = {} tracer_pairs: dict[ - TracerNames, tuple[str, npt.NDArray[np.float64], npt.NDArray[np.float64]] + tuple[TracerNames, str], tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]] ] = {} for i, mn in enumerate(src_bins_centers): @@ -454,8 +713,7 @@ def fixture_sacc_galaxy_xis(): for i, j in upper_triangle_indices(len(lens_bins_centers)): xis = np.random.normal(size=thetas.shape[0]) sacc_data.add_theta_xi("galaxy_density_xi", f"lens{i}", f"lens{j}", thetas, xis) - tracer_pairs[TracerNames(f"lens{i}", f"lens{j}")] = ( - "galaxy_density_xi", + 
tracer_pairs[(TracerNames(f"lens{i}", f"lens{j}"), "galaxy_density_xi")] = ( thetas, xis, ) @@ -466,8 +724,110 @@ def fixture_sacc_galaxy_xis(): sacc_data.add_theta_xi( "galaxy_shearDensity_xi_t", f"src{i}", f"lens{j}", thetas, xis ) - tracer_pairs[TracerNames(f"src{i}", f"lens{j}")] = ( - "galaxy_shearDensity_xi_t", + tracer_pairs[ + (TracerNames(f"src{i}", f"lens{j}"), "galaxy_shearDensity_xi_t") + ] = ( + thetas, + xis, + ) + dv.append(xis) + + for i, j in upper_triangle_indices(len(src_bins_centers)): + xis = np.random.normal(size=thetas.shape[0]) + sacc_data.add_theta_xi( + "galaxy_shear_xi_minus", f"src{i}", f"src{j}", thetas, xis + ) + tracer_pairs[(TracerNames(f"src{i}", f"src{j}"), "galaxy_shear_xi_minus")] = ( + thetas, + xis, + ) + dv.append(xis) + for i, j in upper_triangle_indices(len(src_bins_centers)): + xis = np.random.normal(size=thetas.shape[0]) + sacc_data.add_theta_xi( + "galaxy_shear_xi_plus", f"src{i}", f"src{j}", thetas, xis + ) + tracer_pairs[(TracerNames(f"src{i}", f"src{j}"), "galaxy_shear_xi_plus")] = ( + thetas, + xis, + ) + dv.append(xis) + + delta_v = np.concatenate(dv, axis=0) + cov = np.diag(np.ones_like(delta_v) * 0.01) + + sacc_data.add_covariance(cov) + + return sacc_data, tracers, tracer_pairs + + +@pytest.fixture(name="sacc_galaxy_xis_inverted") +def fixture_sacc_galaxy_xis_inverted(): + """Fixture for a SACC data without window functions.""" + + z = np.linspace(0, 1.0, 256) + 0.05 + thetas = np.linspace(0.0, 2.0 * np.pi, 20) + + sacc_data = sacc.Sacc() + + src_bins_centers = np.linspace(0.25, 0.75, 5) + lens_bins_centers = np.linspace(0.1, 0.6, 5) + + tracers: dict[str, tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]] = {} + tracer_pairs: dict[ + tuple[TracerNames, str], tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]] + ] = {} + + for i, mn in enumerate(src_bins_centers): + dndz = np.exp(-0.5 * (z - mn) ** 2 / 0.05 / 0.05) + sacc_data.add_tracer("NZ", f"src{i}", z, dndz) + tracers[f"src{i}"] = (z, dndz) + + for i, mn in enumerate(lens_bins_centers): + dndz = np.exp(-0.5 * (z - mn) ** 2 / 0.05 / 0.05) + sacc_data.add_tracer("NZ", f"lens{i}", z, dndz) + tracers[f"lens{i}"] = (z, dndz) + + dv = [] + + for i, j in upper_triangle_indices(len(lens_bins_centers)): + xis = np.random.normal(size=thetas.shape[0]) + sacc_data.add_theta_xi("galaxy_density_xi", f"lens{j}", f"lens{i}", thetas, xis) + tracer_pairs[(TracerNames(f"lens{j}", f"lens{i}"), "galaxy_density_xi")] = ( + thetas, + xis, + ) + dv.append(xis) + + for i, j in product(range(len(src_bins_centers)), range(len(lens_bins_centers))): + xis = np.random.normal(size=thetas.shape[0]) + sacc_data.add_theta_xi( + "galaxy_shearDensity_xi_t", f"lens{i}", f"src{j}", thetas, xis + ) + tracer_pairs[ + (TracerNames(f"lens{j}", f"src{i}"), "galaxy_shearDensity_xi_t") + ] = ( + thetas, + xis, + ) + dv.append(xis) + + for i, j in upper_triangle_indices(len(src_bins_centers)): + xis = np.random.normal(size=thetas.shape[0]) + sacc_data.add_theta_xi( + "galaxy_shear_xi_minus", f"src{j}", f"src{i}", thetas, xis + ) + tracer_pairs[(TracerNames(f"src{j}", f"src{i}"), "galaxy_shear_xi_minus")] = ( + thetas, + xis, + ) + dv.append(xis) + for i, j in upper_triangle_indices(len(src_bins_centers)): + xis = np.random.normal(size=thetas.shape[0]) + sacc_data.add_theta_xi( + "galaxy_shear_xi_plus", f"src{j}", f"src{i}", thetas, xis + ) + tracer_pairs[(TracerNames(f"src{j}", f"src{i}"), "galaxy_shear_xi_plus")] = ( thetas, xis, ) @@ -481,6 +841,26 @@ def fixture_sacc_galaxy_xis(): return sacc_data, tracers, 
tracer_pairs +@pytest.fixture(name="sacc_galaxy_cells_ambiguous") +def fixture_sacc_galaxy_cells_ambiguous() -> sacc.Sacc: + """Fixture for a SACC data without window functions with ambiguous tracers.""" + sacc_data = sacc.Sacc() + + z = np.linspace(0, 1.0, 256) + 0.05 + ells = np.unique(np.logspace(1, 3, 10).astype(np.int64)) + + dndz = np.exp(-0.5 * (z - 0.5) ** 2 / 0.05 / 0.05) + sacc_data.add_tracer("NZ", "bin0", z, dndz) + sacc_data.add_tracer("NZ", "bin1", z, dndz) + Cells = np.random.normal(size=ells.shape[0]) + sacc_data.add_ell_cl("galaxy_shearDensity_cl_e", "bin0", "bin1", ells, Cells) + cov = np.diag(np.zeros(len(ells)) + 0.01) + + sacc_data.add_covariance(cov) + + return sacc_data + + @pytest.fixture(name="sacc_galaxy_cells_src0_src0_no_data") def fixture_sacc_galaxy_cells_src0_src0_no_data(): """Fixture for a SACC data without window functions.""" @@ -554,3 +934,15 @@ def fixture_sacc_galaxy_cells_src0_src0_no_window(): sacc_data.add_covariance(cov) return sacc_data, z, dndz + + +@pytest.fixture(name="wl_factory") +def make_wl_factory(): + """Generate a WeakLensingFactory object.""" + return wl.WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) + + +@pytest.fixture(name="nc_factory") +def make_nc_factory(): + """Generate a NumberCountsFactory object.""" + return nc.NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) diff --git a/tests/generators/test_inferred_galaxy_zdist.py b/tests/generators/test_inferred_galaxy_zdist.py index 3956604b..abbd5659 100644 --- a/tests/generators/test_inferred_galaxy_zdist.py +++ b/tests/generators/test_inferred_galaxy_zdist.py @@ -4,6 +4,7 @@ from itertools import pairwise, product import copy +import re import pytest import numpy as np import numpy.typing as npt @@ -22,10 +23,11 @@ ZDistLSSTSRDBinCollection, LSST_Y1_LENS_BIN_COLLECTION, LSST_Y1_SOURCE_BIN_COLLECTION, - make_measurement, - make_measurement_dict, + Measurement, + make_measurements, + make_measurements_dict, ) -from firecrown.metadata.two_point import Galaxies, Clusters, CMB +from firecrown.metadata.two_point_types import Galaxies, Clusters, CMB from firecrown.utils import base_model_from_yaml, base_model_to_yaml @@ -108,7 +110,7 @@ def test_compute_one_bin_dist_fix_z( sigma_z=bins["sigma_z"], z=z_array, name="lens0_y1", - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, ) assert_array_equal(Pz.z, z_array) @@ -129,7 +131,7 @@ def test_compute_all_bins_dist_fix_z( sigma_z=bins["sigma_z"], z=z_array, name="lens_y1", - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, ) assert_array_equal(Pz.z, z_array) @@ -150,7 +152,7 @@ def test_compute_one_bin_dist_autoknot( sigma_z=bins["sigma_z"], z=z_array, name="lens0_y1", - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, use_autoknot=True, autoknots_reltol=reltol, ) @@ -173,7 +175,7 @@ def test_compute_all_bins_dist_autoknot( sigma_z=bins["sigma_z"], z=z_array, name="lens_y1", - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, use_autoknot=True, autoknots_reltol=reltol, ) @@ -188,7 +190,7 @@ def test_zdist_bin(): sigma_z = 0.01 z = LinearGrid1D(start=0.0, end=3.0, num=100) bin_name = "lens0_y1" - measurement = Galaxies.COUNTS + measurements: set[Measurement] = {Galaxies.COUNTS} zbin = ZDistLSSTSRDBin( zpl=zpl, @@ -196,17 +198,15 @@ def test_zdist_bin(): sigma_z=sigma_z, z=z, bin_name=bin_name, - measurement=measurement, + measurements=measurements, ) - print(zbin.model_dump_json(indent=2)) - assert zbin.zpl == zpl assert zbin.zpu == zpu assert zbin.sigma_z 
== sigma_z assert_array_equal(zbin.z.generate(), z.generate()) assert zbin.bin_name == bin_name - assert zbin.measurement == measurement + assert zbin.measurements == measurements def test_zdist_bin_generate(zdist: ZDistLSSTSRD): @@ -216,7 +216,7 @@ def test_zdist_bin_generate(zdist: ZDistLSSTSRD): sigma_z = 0.01 z = LinearGrid1D(start=0.0, end=3.0, num=100) bin_name = "lens0_y1" - measurement = Galaxies.COUNTS + measurements: set[Measurement] = {Galaxies.COUNTS} zbin = ZDistLSSTSRDBin( zpl=zpl, @@ -224,14 +224,14 @@ def test_zdist_bin_generate(zdist: ZDistLSSTSRD): sigma_z=sigma_z, z=z, bin_name=bin_name, - measurement=measurement, + measurements=measurements, ) Pz = zbin.generate(zdist) assert_array_equal(Pz.z, z.generate()) assert Pz.bin_name == bin_name - assert Pz.measurement == measurement + assert Pz.measurements == measurements def test_zdist_bin_from_bad_yaml(): @@ -260,7 +260,7 @@ def test_zdist_bin_from_yaml(): zpu = 0.2 sigma_z = 0.01 bin_name = "lens0_y1" - measurement = Galaxies.COUNTS + measurements = {Galaxies.COUNTS} bin_yaml = """ zpl: 0.1 zpu: 0.2 @@ -270,9 +270,9 @@ def test_zdist_bin_from_yaml(): end: 3.0 num: 100 bin_name: lens0_y1 - measurement: - subject: Galaxies - property: COUNTS + measurements: + - subject: Galaxies + property: COUNTS """ zbin = base_model_from_yaml(ZDistLSSTSRDBin, bin_yaml) @@ -282,7 +282,7 @@ def test_zdist_bin_from_yaml(): assert zbin.sigma_z == sigma_z assert_array_equal(zbin.z.generate(), np.linspace(0.0, 3.0, 100)) assert zbin.bin_name == bin_name - assert zbin.measurement == measurement + assert zbin.measurements == measurements def test_zdist_bin_to_yaml(): @@ -293,7 +293,7 @@ def test_zdist_bin_to_yaml(): sigma_z=0.01, z=LinearGrid1D(start=0.0, end=3.0, num=100), bin_name="lens0_y1", - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, ) assert isinstance(zbin, ZDistLSSTSRDBin) assert isinstance(zbin.z, LinearGrid1D) @@ -307,7 +307,7 @@ def test_zdist_bin_to_yaml(): "sigma_z": zbin.sigma_z, "z": {"start": zbin.z.start, "end": zbin.z.end, "num": zbin.z.num}, "bin_name": zbin.bin_name, - "measurement": make_measurement_dict(zbin.measurement), + "measurements": make_measurements_dict(zbin.measurements), "use_autoknot": zbin.use_autoknot, "autoknots_reltol": zbin.autoknots_reltol, "autoknots_abstol": zbin.autoknots_abstol, @@ -355,7 +355,7 @@ def test_zdist_bin_collection_generate(zdist_bins): assert not np.array_equal(Pz.z, zbin.z.generate()) assert Pz.bin_name == zbin.bin_name - assert Pz.measurement == zbin.measurement + assert Pz.measurements == zbin.measurements def test_zdist_bin_collection_to_yaml(zdist_bins): @@ -384,7 +384,7 @@ def test_zdist_bin_collection_to_yaml(zdist_bins): "sigma_z": zbin.sigma_z, "z": zbin.z.model_dump(), "bin_name": zbin.bin_name, - "measurement": make_measurement_dict(zbin.measurement), + "measurements": make_measurements_dict(zbin.measurements), "use_autoknot": zbin.use_autoknot, "autoknots_reltol": zbin.autoknots_reltol, "autoknots_abstol": zbin.autoknots_abstol, @@ -411,7 +411,7 @@ def test_zdist_bin_collection_from_yaml(zdist_bins): "sigma_z": zbin.sigma_z, "z": zbin.z.model_dump(), "bin_name": zbin.bin_name, - "measurement": make_measurement_dict(zbin.measurement), + "measurements": make_measurements_dict(zbin.measurements), "use_autoknot": zbin.use_autoknot, "autoknots_reltol": zbin.autoknots_reltol, "autoknots_abstol": zbin.autoknots_abstol, @@ -433,39 +433,41 @@ def test_zdist_bin_collection_from_yaml(zdist_bins): assert zbin.sigma_z == zbin_from_yaml.sigma_z 
assert_array_equal(zbin.z.generate(), zbin_from_yaml.z.generate()) assert zbin.bin_name == zbin_from_yaml.bin_name - assert zbin.measurement == zbin_from_yaml.measurement + assert zbin.measurements == zbin_from_yaml.measurements assert zbin.use_autoknot == zbin_from_yaml.use_autoknot assert zbin.autoknots_reltol == zbin_from_yaml.autoknots_reltol assert zbin.autoknots_abstol == zbin_from_yaml.autoknots_abstol def test_make_measurement_from_measurement(): - cluster_meas = Clusters.COUNTS - galaxy_meas = Galaxies.SHEAR_E - cmb_meas = CMB.CONVERGENCE - assert make_measurement(cluster_meas) == cluster_meas - assert make_measurement(galaxy_meas) == galaxy_meas - assert make_measurement(cmb_meas) == cmb_meas + cluster_meas: set[Measurement] = {Clusters.COUNTS} + galaxy_meas: set[Measurement] = {Galaxies.SHEAR_E} + cmb_meas: set[Measurement] = {CMB.CONVERGENCE} + assert make_measurements(cluster_meas) == cluster_meas + assert make_measurements(galaxy_meas) == galaxy_meas + assert make_measurements(cmb_meas) == cmb_meas def test_make_measurement_from_dictionary(): - cluster_info = {"subject": "Clusters", "property": "COUNTS"} - galaxy_info = {"subject": "Galaxies", "property": "SHEAR_E"} - cmb_info = {"subject": "CMB", "property": "CONVERGENCE"} + cluster_info = [{"subject": "Clusters", "property": "COUNTS"}] + galaxy_info = [{"subject": "Galaxies", "property": "SHEAR_E"}] + cmb_info = [{"subject": "CMB", "property": "CONVERGENCE"}] - assert make_measurement(cluster_info) == Clusters.COUNTS - assert make_measurement(galaxy_info) == Galaxies.SHEAR_E - assert make_measurement(cmb_info) == CMB.CONVERGENCE + assert make_measurements(cluster_info) == {Clusters.COUNTS} + assert make_measurements(galaxy_info) == {Galaxies.SHEAR_E} + assert make_measurements(cmb_info) == {CMB.CONVERGENCE} with pytest.raises( ValueError, match="Invalid Measurement: subject: 'frogs' is not recognized" ): - _ = make_measurement({"subject": "frogs", "property": "SHEAR_E"}) + _ = make_measurements([{"subject": "frogs", "property": "SHEAR_E"}]) with pytest.raises( ValueError, match="Invalid Measurement: dictionary does not contain 'subject'" ): - _ = make_measurement({}) + _ = make_measurements([{}]) - with pytest.raises(ValueError, match="Invalid Measurement: 3 is not a dictionary"): - _ = make_measurement(3) # type: ignore + with pytest.raises( + ValueError, match=re.escape(r"Invalid Measurement: {3} is not a dictionary") + ): + _ = make_measurements({3}) # type: ignore diff --git a/tests/likelihood/gauss_family/statistic/source/test_source.py b/tests/likelihood/gauss_family/statistic/source/test_source.py index cd20a145..bf2146a9 100644 --- a/tests/likelihood/gauss_family/statistic/source/test_source.py +++ b/tests/likelihood/gauss_family/statistic/source/test_source.py @@ -19,11 +19,7 @@ ) import firecrown.likelihood.number_counts as nc import firecrown.likelihood.weak_lensing as wl -from firecrown.metadata.two_point import ( - extract_all_tracers, - InferredGalaxyZDist, - Galaxies, -) +from firecrown.metadata.two_point import extract_all_tracers from firecrown.parameters import ParamsMap @@ -353,14 +349,10 @@ def test_number_counts_source_init_wrong_name(sacc_galaxy_cells_lens0_lens0): source.read(sacc_data) -def test_number_counts_systematic_factory(nc_sys_factory): - bin_1 = InferredGalaxyZDist( - bin_name="bin_1", - z=np.array([1.0]), - dndz=np.array([1.0]), - measurement=Galaxies.COUNTS, - ) - sys_pz_shift = nc_sys_factory.create(bin_1) +def test_number_counts_systematic_factory( + nc_sys_factory: 
nc.NumberCountsSystematicFactory, +): + sys_pz_shift = nc_sys_factory.create("bin_1") assert sys_pz_shift.parameter_prefix == "bin_1" @@ -404,12 +396,8 @@ def test_nc_constantmagnificationbiassystematicfactory_no_globals(): _ = factory.create_global() -def test_weak_lensing_systematic_factory(wl_sys_factory): - bin_1 = InferredGalaxyZDist( - bin_name="bin_1", - z=np.array([1.0]), - dndz=np.array([1.0]), - measurement=Galaxies.SHEAR_E, - ) - sys_pz_shift = wl_sys_factory.create(bin_1) +def test_weak_lensing_systematic_factory( + wl_sys_factory: wl.WeakLensingSystematicFactory, +): + sys_pz_shift = wl_sys_factory.create("bin_1") assert sys_pz_shift.parameter_prefix == "bin_1" diff --git a/tests/likelihood/gauss_family/statistic/test_two_point.py b/tests/likelihood/gauss_family/statistic/test_two_point.py index f530e1f0..fb4aa524 100644 --- a/tests/likelihood/gauss_family/statistic/test_two_point.py +++ b/tests/likelihood/gauss_family/statistic/test_two_point.py @@ -3,6 +3,7 @@ """ import re +from unittest.mock import MagicMock import numpy as np import pytest @@ -23,6 +24,20 @@ TracerNames, TRACER_NAMES_TOTAL, EllOrThetaConfig, + use_source_factory, + use_source_factory_metadata_only, + WeakLensingFactory, + NumberCountsFactory, +) +from firecrown.metadata.two_point_types import ( + Galaxies, + InferredGalaxyZDist, + GALAXY_LENS_TYPES, + GALAXY_SOURCE_TYPES, +) +from firecrown.metadata.two_point import ( + TwoPointCellsIndex, + TwoPointXiThetaIndex, ) @@ -385,3 +400,116 @@ def test_two_point_lens0_lens0_data_and_conf_warn(sacc_galaxy_xis_lens0_lens0): ), ): statistic.read(sacc_data) + + +def test_use_source_factory(harmonic_bin_1: InferredGalaxyZDist): + wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) + nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) + + measurement = list(harmonic_bin_1.measurements)[0] + source = use_source_factory(harmonic_bin_1, measurement, wl_factory, nc_factory) + + if measurement in GALAXY_LENS_TYPES: + assert isinstance(source, NumberCounts) + elif measurement in GALAXY_SOURCE_TYPES: + assert isinstance(source, WeakLensing) + else: + assert False, f"Unknown measurement type: {measurement}" + + +def test_use_source_factory_invalid_measurement(harmonic_bin_1: InferredGalaxyZDist): + with pytest.raises( + ValueError, + match="Measurement .* not found in inferred galaxy redshift distribution .*", + ): + use_source_factory(harmonic_bin_1, Galaxies.SHEAR_MINUS, None, None) + + +def test_use_source_factory_metadata_only_counts(): + wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) + nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) + source = use_source_factory_metadata_only( + "bin1", Galaxies.COUNTS, wl_factory=wl_factory, nc_factory=nc_factory + ) + assert isinstance(source, NumberCounts) + + +def test_use_source_factory_metadata_only_shear(): + wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) + nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) + source = use_source_factory_metadata_only( + "bin1", Galaxies.SHEAR_E, wl_factory=wl_factory, nc_factory=nc_factory + ) + assert isinstance(source, WeakLensing) + + +def test_use_source_factory_metadata_only_invalid_measurement(): + wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) + nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) + with pytest.raises(ValueError, match="Unknown measurement 
type encountered .*"): + use_source_factory_metadata_only( + "bin1", 120, wl_factory=wl_factory, nc_factory=nc_factory # type: ignore + ) + + +def test_two_point_wrong_type(): + with pytest.raises(ValueError, match="The SACC data type cow is not supported!"): + TwoPoint( + "cow", WeakLensing(sacc_tracer="calma"), WeakLensing(sacc_tracer="fernando") + ) + + +def test_from_metadata_harmonic_wrong_metadata(): + with pytest.raises( + ValueError, match=re.escape("Metadata of type is not supported") + ): + TwoPoint._from_metadata( # pylint: disable=protected-access + sacc_data_type="galaxy_density_xi", + source0=NumberCounts(sacc_tracer="lens_0"), + source1=NumberCounts(sacc_tracer="lens_0"), + metadata="NotAMetadata", # type: ignore + ) + + +def test_use_source_factory_metadata_only_wrong_measurement(): + unknown_type = MagicMock() + unknown_type.configure_mock(__eq__=MagicMock(return_value=False)) + + with pytest.raises(ValueError, match="Measurement .* not supported!"): + use_source_factory_metadata_only( + "bin1", unknown_type, wl_factory=None, nc_factory=None + ) + + +def test_from_metadata_only_harmonic(): + wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) + nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) + metadata: TwoPointCellsIndex = { + "data_type": "galaxy_density_xi", + "tracer_names": TracerNames("lens0", "lens0"), + "ells": np.array(np.linspace(0, 100, 100), dtype=np.int64), + } + two_point = TwoPoint.from_metadata_only_harmonic( + [metadata], + wl_factory=wl_factory, + nc_factory=nc_factory, + ).pop() + assert isinstance(two_point, TwoPoint) + assert not two_point.ready + + +def test_from_metadata_only_real(): + wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) + nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) + metadata: TwoPointXiThetaIndex = { + "data_type": "galaxy_shear_xi_plus", + "tracer_names": TracerNames("src0", "src0"), + "thetas": np.linspace(0.0, 1.0, 100), + } + two_point = TwoPoint.from_metadata_only_real( + [metadata], + wl_factory=wl_factory, + nc_factory=nc_factory, + ).pop() + assert isinstance(two_point, TwoPoint) + assert not two_point.ready diff --git a/tests/likelihood/gauss_family/test_const_gaussian.py b/tests/likelihood/gauss_family/test_const_gaussian.py index ff1275ba..6736e86d 100644 --- a/tests/likelihood/gauss_family/test_const_gaussian.py +++ b/tests/likelihood/gauss_family/test_const_gaussian.py @@ -21,6 +21,16 @@ DerivedParameterCollection, SamplerParameter, ) +from firecrown.likelihood.two_point import ( + TwoPoint, + WeakLensingFactory, + NumberCountsFactory, +) +from firecrown.metadata.two_point import ( + extract_all_data_cells, + TwoPointCellsIndex, + TracerNames, +) class StatisticWithoutIndices(TrivialStatistic): @@ -469,3 +479,61 @@ def test_access_required_parameters( likelihood = ConstGaussian(statistics=trivial_stats) params = likelihood.required_parameters().get_default_values() assert params == {"mean": 0.0} + + +def test_create_ready(sacc_galaxy_cwindows): + sacc_data, _, _ = sacc_galaxy_cwindows + _, two_point_cwindows = extract_all_data_cells(sacc_data) + wl_factory = WeakLensingFactory(global_systematics=[], per_bin_systematics=[]) + nc_factory = NumberCountsFactory(global_systematics=[], per_bin_systematics=[]) + + two_points = TwoPoint.from_metadata_harmonic( + two_point_cwindows, wl_factory, nc_factory + ) + size = np.sum([len(two_point.get_data_vector()) for two_point in two_points]) + + likelihood = 
ConstGaussian.create_ready(two_points, np.diag(np.ones(size))) + assert likelihood is not None + assert isinstance(likelihood, ConstGaussian) + + +def test_create_ready_wrong_size(sacc_galaxy_cwindows): + sacc_data, _, _ = sacc_galaxy_cwindows + _, two_point_cwindows = extract_all_data_cells(sacc_data) + wl_factory = WeakLensingFactory(global_systematics=[], per_bin_systematics=[]) + nc_factory = NumberCountsFactory(global_systematics=[], per_bin_systematics=[]) + + two_points = TwoPoint.from_metadata_harmonic( + two_point_cwindows, wl_factory, nc_factory + ) + size = np.sum([len(two_point.get_data_vector()) for two_point in two_points]) + + with pytest.raises( + ValueError, + match=re.escape( + f"The covariance matrix has shape (3, 3), " + f"but the expected shape is at least ({size}, {size})." + ), + ): + ConstGaussian.create_ready(two_points, np.diag([1.0, 2.0, 3.0])) + + +def test_create_ready_not_ready(): + wl_factory = WeakLensingFactory(global_systematics=[], per_bin_systematics=[]) + nc_factory = NumberCountsFactory(global_systematics=[], per_bin_systematics=[]) + + metadata: TwoPointCellsIndex = { + "data_type": "galaxy_density_xi", + "tracer_names": TracerNames("lens0", "lens0"), + "ells": np.array(np.linspace(0, 100, 100), dtype=np.int64), + } + + two_points = TwoPoint.from_metadata_only_harmonic( + [metadata], wl_factory, nc_factory + ) + + with pytest.raises( + RuntimeError, + match="The statistic .* is not ready to be used.", + ): + ConstGaussian.create_ready(two_points, np.diag(np.ones(11))) diff --git a/tests/metadata/sacc_name_mapping.py b/tests/metadata/sacc_name_mapping.py index 491efaca..bf4059ec 100644 --- a/tests/metadata/sacc_name_mapping.py +++ b/tests/metadata/sacc_name_mapping.py @@ -1,6 +1,6 @@ """Module containing testing data for translation of SACC names.""" -from firecrown.metadata.two_point import ( +from firecrown.metadata.two_point_types import ( Clusters, CMB, Galaxies, diff --git a/tests/metadata/test_metadata_two_point.py b/tests/metadata/test_metadata_two_point.py index 447b9bdc..d0357f0c 100644 --- a/tests/metadata/test_metadata_two_point.py +++ b/tests/metadata/test_metadata_two_point.py @@ -2,262 +2,46 @@ Tests for the module firecrown.metadata.two_point """ -from itertools import product, chain -from unittest.mock import MagicMock import pytest import numpy as np from numpy.testing import assert_array_equal - -import sacc_name_mapping as snm -from firecrown.metadata.two_point import ( +from firecrown.metadata.two_point_types import ( ALL_MEASUREMENTS, + type_to_sacc_string_harmonic as harmonic, + type_to_sacc_string_real as real, +) + +from firecrown.metadata.two_point_types import ( Clusters, CMB, - compare_enums, Galaxies, InferredGalaxyZDist, - measurement_is_compatible as is_compatible, - measurement_supports_harmonic as supports_harmonic, - measurement_supports_real as supports_real, TracerNames, TwoPointCells, TwoPointCWindow, TwoPointXY, TwoPointXiTheta, - type_to_sacc_string_harmonic as harmonic, - type_to_sacc_string_real as real, + TwoPointMeasurement, Window, ) + from firecrown.likelihood.source import SourceGalaxy -import firecrown.likelihood.weak_lensing as wl -import firecrown.likelihood.number_counts as nc from firecrown.likelihood.two_point import TwoPoint -@pytest.fixture( - name="harmonic_bin_1", - params=[Galaxies.COUNTS, Galaxies.SHEAR_E], -) -def make_harmonic_bin_1(request) -> InferredGalaxyZDist: - """Generate an InferredGalaxyZDist object with 5 bins.""" - x = InferredGalaxyZDist( - bin_name="bin_1", - z=np.linspace(0, 1, 
5), - dndz=np.array([0.1, 0.5, 0.2, 0.3, 0.4]), - measurement=request.param, - ) - return x - - -@pytest.fixture( - name="harmonic_bin_2", - params=[Galaxies.COUNTS, Galaxies.SHEAR_E], -) -def make_harmonic_bin_2(request) -> InferredGalaxyZDist: - """Generate an InferredGalaxyZDist object with 3 bins.""" - x = InferredGalaxyZDist( - bin_name="bin_2", - z=np.linspace(0, 1, 3), - dndz=np.array([0.1, 0.5, 0.4]), - measurement=request.param, - ) - return x - - -@pytest.fixture( - name="real_bin_1", - params=[Galaxies.COUNTS, Galaxies.SHEAR_T], -) -def make_real_bin_1(request) -> InferredGalaxyZDist: - """Generate an InferredGalaxyZDist object with 5 bins.""" - x = InferredGalaxyZDist( - bin_name="bin_1", - z=np.linspace(0, 1, 5), - dndz=np.array([0.1, 0.5, 0.2, 0.3, 0.4]), - measurement=request.param, - ) - return x - - -@pytest.fixture( - name="real_bin_2", - params=[Galaxies.COUNTS, Galaxies.SHEAR_T], -) -def make_real_bin_2(request) -> InferredGalaxyZDist: - """Generate an InferredGalaxyZDist object with 3 bins.""" - x = InferredGalaxyZDist( - bin_name="bin_2", - z=np.linspace(0, 1, 3), - dndz=np.array([0.1, 0.5, 0.4]), - measurement=request.param, - ) - return x - - -@pytest.fixture(name="window_1") -def make_window_1() -> Window: - """Generate a Window object with 100 ells.""" - ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64) - weights = np.ones(400).reshape(-1, 4) - - window = Window( - ells=ells, - weights=weights, - ells_for_interpolation=ells_for_interpolation, - ) - return window - - -@pytest.fixture(name="two_point_cwindow_1") -def make_two_point_cwindow_1( - window_1: Window, - harmonic_bin_1: InferredGalaxyZDist, - harmonic_bin_2: InferredGalaxyZDist, -) -> TwoPointCWindow: - """Generate a TwoPointCWindow object with 100 ells.""" - xy = TwoPointXY(x=harmonic_bin_1, y=harmonic_bin_2) - two_point = TwoPointCWindow(XY=xy, window=window_1) - return two_point - - -@pytest.fixture(name="wl_factory") -def make_wl_factory(): - """Generate a WeakLensingFactory object.""" - return wl.WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - - -@pytest.fixture(name="nc_factory") -def make_nc_factory(): - """Generate a NumberCountsFactory object.""" - return nc.NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) - - -def test_order_enums(): - assert compare_enums(CMB.CONVERGENCE, Clusters.COUNTS) < 0 - assert compare_enums(Clusters.COUNTS, CMB.CONVERGENCE) > 0 - - assert compare_enums(CMB.CONVERGENCE, Galaxies.COUNTS) < 0 - assert compare_enums(Galaxies.COUNTS, CMB.CONVERGENCE) > 0 - - assert compare_enums(Galaxies.SHEAR_E, Galaxies.SHEAR_T) < 0 - assert compare_enums(Galaxies.SHEAR_E, Galaxies.COUNTS) < 0 - assert compare_enums(Galaxies.SHEAR_T, Galaxies.COUNTS) < 0 - - assert compare_enums(Galaxies.COUNTS, Galaxies.SHEAR_E) > 0 - - for enumerand in ALL_MEASUREMENTS: - assert compare_enums(enumerand, enumerand) == 0 - - -def test_enumeration_equality_galaxy(): - for e1, e2 in product(Galaxies, chain(CMB, Clusters)): - assert e1 != e2 - - -def test_enumeration_equality_cmb(): - for e1, e2 in product(CMB, chain(Galaxies, Clusters)): - assert e1 != e2 - - -def test_enumeration_equality_cluster(): - for e1, e2 in product(Clusters, chain(CMB, Galaxies)): - assert e1 != e2 - - -def test_exact_matches(): - for sacc_name, space, (enum_1, enum_2) in snm.mappings: - if space == "ell": - assert harmonic(enum_1, enum_2) == sacc_name - elif space == "theta": - assert real(enum_1, enum_2) == 
sacc_name - else: - raise ValueError(f"Illegal 'space' value {space} in testing data") - - -def test_translation_invariants(): - for a, b in product(ALL_MEASUREMENTS, ALL_MEASUREMENTS): - assert isinstance(a, (Galaxies, CMB, Clusters)) - assert isinstance(b, (Galaxies, CMB, Clusters)) - if supports_real(a) and supports_real(b): - assert real(a, b) == real(b, a) - if supports_harmonic(a) and supports_harmonic(b): - assert harmonic(a, b) == harmonic(b, a) - if ( - supports_harmonic(a) - and supports_harmonic(b) - and supports_real(a) - and supports_real(b) - ): - assert harmonic(a, b) != real(a, b) - - -def test_unsupported_type_galaxy(): - unknown_type = MagicMock() - unknown_type.configure_mock(__eq__=MagicMock(return_value=False)) - - with pytest.raises(ValueError, match="Untranslated Galaxy Measurement encountered"): - Galaxies.sacc_measurement_name(unknown_type) - - with pytest.raises(ValueError, match="Untranslated Galaxy Measurement encountered"): - Galaxies.polarization(unknown_type) - - -def test_unsupported_type_cmb(): - unknown_type = MagicMock() - unknown_type.configure_mock(__eq__=MagicMock(return_value=False)) - - with pytest.raises(ValueError, match="Untranslated CMBMeasurement encountered"): - CMB.sacc_measurement_name(unknown_type) - - with pytest.raises(ValueError, match="Untranslated CMBMeasurement encountered"): - CMB.polarization(unknown_type) - - -def test_unsupported_type_cluster(): - unknown_type = MagicMock() - unknown_type.configure_mock(__eq__=MagicMock(return_value=False)) - - with pytest.raises(ValueError, match="Untranslated ClusterMeasurement encountered"): - Clusters.sacc_measurement_name(unknown_type) - - with pytest.raises(ValueError, match="Untranslated ClusterMeasurement encountered"): - Clusters.polarization(unknown_type) - - -def test_type_hashs(): - for e1, e2 in product(ALL_MEASUREMENTS, ALL_MEASUREMENTS): - if e1 == e2: - assert hash(e1) == hash(e2) - else: - assert hash(e1) != hash(e2) - - -def test_measurement_is_compatible(): - for a, b in product(ALL_MEASUREMENTS, ALL_MEASUREMENTS): - assert isinstance(a, (Galaxies, CMB, Clusters)) - assert isinstance(b, (Galaxies, CMB, Clusters)) - if (supports_real(a) and supports_real(b)) or ( - supports_harmonic(a) and supports_harmonic(b) - ): - assert is_compatible(a, b) - else: - assert not is_compatible(a, b) - - def test_inferred_galaxy_z_dist(): z_dist = InferredGalaxyZDist( bin_name="bname1", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, ) assert z_dist.bin_name == "bname1" assert z_dist.z[0] == 0 assert z_dist.z[-1] == 1 assert z_dist.dndz[0] == 1 assert z_dist.dndz[-1] == 1 - assert z_dist.measurement == Galaxies.COUNTS + assert z_dist.measurements == {Galaxies.COUNTS} def test_inferred_galaxy_z_dist_bad_shape(): @@ -268,7 +52,7 @@ def test_inferred_galaxy_z_dist_bad_shape(): bin_name="bname1", z=np.linspace(0, 1, 100), dndz=np.ones(101), - measurement=Clusters.COUNTS, + measurements={Clusters.COUNTS}, ) @@ -278,7 +62,7 @@ def test_inferred_galaxy_z_dist_bad_type(): bin_name="bname1", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=0, # type: ignore + measurements={0}, # type: ignore ) @@ -288,7 +72,7 @@ def test_inferred_galaxy_z_dist_bad_name(): bin_name="", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, ) @@ -297,17 +81,65 @@ def test_two_point_xy_gal_gal(): bin_name="bname1", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.COUNTS, + 
measurements={Galaxies.COUNTS}, ) y = InferredGalaxyZDist( bin_name="bname2", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, + ) + xy = TwoPointXY( + x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=Galaxies.COUNTS ) - xy = TwoPointXY(x=x, y=y) assert xy.x == x assert xy.y == y + assert xy.x_measurement == Galaxies.COUNTS + assert xy.y_measurement == Galaxies.COUNTS + + +def test_two_point_xy_gal_gal_invalid_x_measurement(): + x = InferredGalaxyZDist( + bin_name="bname1", + z=np.linspace(0, 1, 100), + dndz=np.ones(100), + measurements={Galaxies.SHEAR_E}, + ) + y = InferredGalaxyZDist( + bin_name="bname2", + z=np.linspace(0, 1, 100), + dndz=np.ones(100), + measurements={Galaxies.COUNTS}, + ) + with pytest.raises( + ValueError, + match="Measurement .* not in the measurements of bname1.", + ): + TwoPointXY( + x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=Galaxies.COUNTS + ) + + +def test_two_point_xy_gal_gal_invalid_y_measurement(): + x = InferredGalaxyZDist( + bin_name="bname1", + z=np.linspace(0, 1, 100), + dndz=np.ones(100), + measurements={Galaxies.COUNTS}, + ) + y = InferredGalaxyZDist( + bin_name="bname2", + z=np.linspace(0, 1, 100), + dndz=np.ones(100), + measurements={Galaxies.SHEAR_E}, + ) + with pytest.raises( + ValueError, + match="Measurement .* not in the measurements of bname2.", + ): + TwoPointXY( + x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=Galaxies.COUNTS + ) def test_two_point_xy_cmb_gal(): @@ -315,17 +147,21 @@ def test_two_point_xy_cmb_gal(): bin_name="bname1", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, ) y = InferredGalaxyZDist( bin_name="bname2", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=CMB.CONVERGENCE, + measurements={CMB.CONVERGENCE}, + ) + xy = TwoPointXY( + x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=CMB.CONVERGENCE ) - xy = TwoPointXY(x=x, y=y) assert xy.x == x assert xy.y == y + assert xy.x_measurement == Galaxies.COUNTS + assert xy.y_measurement == CMB.CONVERGENCE def test_two_point_xy_invalid(): @@ -333,19 +169,21 @@ def test_two_point_xy_invalid(): bin_name="bname1", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.SHEAR_E, + measurements={Galaxies.SHEAR_E}, ) y = InferredGalaxyZDist( bin_name="bname2", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.SHEAR_T, + measurements={Galaxies.SHEAR_T}, ) with pytest.raises( ValueError, match=("Measurements .* and .* are not compatible."), ): - TwoPointXY(x=x, y=y) + TwoPointXY( + x=x, y=y, x_measurement=Galaxies.SHEAR_E, y_measurement=Galaxies.SHEAR_T + ) def test_two_point_cells(): @@ -353,21 +191,23 @@ def test_two_point_cells(): bin_name="bname1", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, ) y = InferredGalaxyZDist( bin_name="bname2", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, + ) + xy = TwoPointXY( + x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=Galaxies.COUNTS ) - xy = TwoPointXY(x=x, y=y) ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) cells = TwoPointCells(ells=ells, XY=xy) assert_array_equal(cells.ells, ells) assert cells.XY == xy - assert cells.get_sacc_name() == harmonic(x.measurement, y.measurement) + assert cells.get_sacc_name() == harmonic(xy.x_measurement, xy.y_measurement) def test_two_point_cells_invalid_ells(): @@ -375,15 +215,17 @@ 
def test_two_point_cells_invalid_ells(): bin_name="bname1", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, ) y = InferredGalaxyZDist( bin_name="bname2", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, + ) + xy = TwoPointXY( + x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=Galaxies.COUNTS ) - xy = TwoPointXY(x=x, y=y) ells = np.array(np.linspace(0, 100), dtype=np.int64).reshape(-1, 10) with pytest.raises( ValueError, @@ -397,15 +239,17 @@ def test_two_point_cells_invalid_type(): bin_name="bname1", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, ) y = InferredGalaxyZDist( bin_name="bname2", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.SHEAR_T, + measurements={Galaxies.SHEAR_T}, + ) + xy = TwoPointXY( + x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=Galaxies.SHEAR_T ) - xy = TwoPointXY(x=x, y=y) ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) with pytest.raises( ValueError, @@ -529,7 +373,7 @@ def test_two_point_window_invalid_weights_shape(): ) -def test_two_point_two_point_cwindow(): +def test_two_point_two_point_cwindow(harmonic_two_point_xy: TwoPointXY): ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64) weights = np.ones(400).reshape(-1, 4) @@ -540,25 +384,64 @@ def test_two_point_two_point_cwindow(): ells_for_interpolation=ells_for_interpolation, ) - x = InferredGalaxyZDist( - bin_name="bname1", - z=np.linspace(0, 1, 100), - dndz=np.ones(100), - measurement=Galaxies.COUNTS, + ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) + two_point = TwoPointCWindow(XY=harmonic_two_point_xy, window=window) + + assert two_point.window == window + assert two_point.XY == harmonic_two_point_xy + assert two_point.get_sacc_name() == harmonic( + harmonic_two_point_xy.x_measurement, harmonic_two_point_xy.y_measurement ) - y = InferredGalaxyZDist( - bin_name="bname2", - z=np.linspace(0, 1, 100), - dndz=np.ones(100), - measurement=Galaxies.COUNTS, + + +def test_two_point_two_point_cwindow_wrong_data_shape( + harmonic_two_point_xy: TwoPointXY, +): + ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) + ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64) + weights = np.ones(400).reshape(-1, 4) + + window = Window( + ells=ells, + weights=weights, + ells_for_interpolation=ells_for_interpolation, ) - xy = TwoPointXY(x=x, y=y) + + data = np.zeros(100) + 1.1 + indices = np.arange(100) + covariance_name = "cov" + data = data.reshape(-1, 10) + with pytest.raises( + ValueError, + match="Data should be a 1D array.", + ): + TwoPointCWindow( + XY=harmonic_two_point_xy, + window=window, + Cell=TwoPointMeasurement( + data=data, + indices=indices, + covariance_name=covariance_name, + ), + ) + + +def test_two_point_two_point_cwindow_stringify(harmonic_two_point_xy: TwoPointXY): ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - two_point = TwoPointCWindow(XY=xy, window=window) + ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64) + weights = np.ones(400).reshape(-1, 4) - assert two_point.window == window - assert two_point.XY == xy - assert two_point.get_sacc_name() == harmonic(x.measurement, y.measurement) + window = Window( + ells=ells, + weights=weights, + ells_for_interpolation=ells_for_interpolation, + ) + + two_point = 
TwoPointCWindow(XY=harmonic_two_point_xy, window=window) + + assert ( + str(two_point) == f"{str(harmonic_two_point_xy)}[{two_point.get_sacc_name()}]" + ) def test_two_point_two_point_cwindow_invalid(): @@ -576,15 +459,17 @@ def test_two_point_two_point_cwindow_invalid(): bin_name="bname1", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, ) y = InferredGalaxyZDist( bin_name="bname2", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.SHEAR_T, + measurements={Galaxies.SHEAR_T}, + ) + xy = TwoPointXY( + x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=Galaxies.SHEAR_T ) - xy = TwoPointXY(x=x, y=y) ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) with pytest.raises( ValueError, @@ -598,15 +483,17 @@ def test_two_point_two_point_cwindow_invalid_window(): bin_name="bname1", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, ) y = InferredGalaxyZDist( bin_name="bname2", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.SHEAR_T, + measurements={Galaxies.SHEAR_T}, + ) + xy = TwoPointXY( + x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=Galaxies.SHEAR_T ) - xy = TwoPointXY(x=x, y=y) with pytest.raises( ValueError, match="Window should be a Window object.", @@ -615,26 +502,27 @@ def test_two_point_two_point_cwindow_invalid_window(): def test_two_point_xi_theta(): - x = InferredGalaxyZDist( bin_name="bname1", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, ) y = InferredGalaxyZDist( bin_name="bname2", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, + ) + xy = TwoPointXY( + x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=Galaxies.COUNTS ) - xy = TwoPointXY(x=x, y=y) theta = np.array(np.linspace(0, 100, 100)) two_point = TwoPointXiTheta(XY=xy, thetas=theta) assert_array_equal(two_point.thetas, theta) assert two_point.XY == xy - assert two_point.get_sacc_name() == real(x.measurement, y.measurement) + assert two_point.get_sacc_name() == real(xy.x_measurement, xy.y_measurement) def test_two_point_xi_theta_invalid(): @@ -642,15 +530,17 @@ def test_two_point_xi_theta_invalid(): bin_name="bname1", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, ) y = InferredGalaxyZDist( bin_name="bname2", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.SHEAR_E, + measurements={Galaxies.SHEAR_E}, + ) + xy = TwoPointXY( + x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=Galaxies.SHEAR_E ) - xy = TwoPointXY(x=x, y=y) theta = np.array(np.linspace(0, 100, 100)) with pytest.raises( ValueError, @@ -695,26 +585,44 @@ def test_inferred_galaxy_zdist_serialization(harmonic_bin_1: InferredGalaxyZDist assert harmonic_bin_1 == recovered -def test_two_point_xy_serialization( - harmonic_bin_1: InferredGalaxyZDist, harmonic_bin_2: InferredGalaxyZDist -): - xy = TwoPointXY(x=harmonic_bin_1, y=harmonic_bin_2) - s = xy.to_yaml() +def test_two_point_xy_str(harmonic_two_point_xy: TwoPointXY): + assert str(harmonic_two_point_xy) == ( + f"({harmonic_two_point_xy.x.bin_name}, " f"{harmonic_two_point_xy.y.bin_name})" + ) + + +def test_two_point_xy_serialization(harmonic_two_point_xy: TwoPointXY): + s = harmonic_two_point_xy.to_yaml() # Take a look at how hideous the generated string # is. 
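+    # Round-tripping through YAML should recover an equal object with the same
+    # string representation; both are checked below.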
recovered = TwoPointXY.from_yaml(s) - assert xy == recovered + assert harmonic_two_point_xy == recovered + assert str(harmonic_two_point_xy) == str(recovered) -def test_two_point_cells_serialization( - harmonic_bin_1: InferredGalaxyZDist, harmonic_bin_2: InferredGalaxyZDist -): +def test_two_point_cells_str(harmonic_two_point_xy: TwoPointXY): ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - xy = TwoPointXY(x=harmonic_bin_1, y=harmonic_bin_2) - cells = TwoPointCells(ells=ells, XY=xy) + cells = TwoPointCells(ells=ells, XY=harmonic_two_point_xy) + assert str(cells) == f"{str(harmonic_two_point_xy)}[{cells.get_sacc_name()}]" + + +def test_two_point_cells_serialization(harmonic_two_point_xy: TwoPointXY): + ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) + cells = TwoPointCells(ells=ells, XY=harmonic_two_point_xy) s = cells.to_yaml() recovered = TwoPointCells.from_yaml(s) assert cells == recovered + assert str(harmonic_two_point_xy) == str(recovered.XY) + assert str(cells) == str(recovered) + + +def test_two_point_cells_ells_wrong_shape(harmonic_two_point_xy: TwoPointXY): + ells = np.array(np.linspace(0, 100), dtype=np.int64).reshape(-1, 10) + with pytest.raises( + ValueError, + match="Ells should be a 1D array.", + ): + TwoPointCells(ells=ells, XY=harmonic_two_point_xy) def test_window_serialization(window_1: Window): @@ -723,30 +631,35 @@ def test_window_serialization(window_1: Window): assert window_1 == recovered -def test_two_point_cwindow_serialization(two_point_cwindow_1: TwoPointCWindow): - s = two_point_cwindow_1.to_yaml() +def test_two_point_cwindow_serialization(two_point_cwindow: TwoPointCWindow): + s = two_point_cwindow.to_yaml() recovered = TwoPointCWindow.from_yaml(s) - assert two_point_cwindow_1 == recovered + assert two_point_cwindow == recovered -def test_two_point_xi_theta_serialization( - real_bin_1: InferredGalaxyZDist, real_bin_2: InferredGalaxyZDist -): - xy = TwoPointXY(x=real_bin_1, y=real_bin_2) +def test_two_point_xi_theta_serialization(real_two_point_xy: TwoPointXY): theta = np.array(np.linspace(0, 10, 10)) - xi_theta = TwoPointXiTheta(XY=xy, thetas=theta) + xi_theta = TwoPointXiTheta(XY=real_two_point_xy, thetas=theta) s = xi_theta.to_yaml() recovered = TwoPointXiTheta.from_yaml(s) assert xi_theta == recovered + assert str(real_two_point_xy) == str(recovered.XY) + assert str(xi_theta) == str(recovered) -def test_two_point_from_metadata_cells( - harmonic_bin_1, harmonic_bin_2, wl_factory, nc_factory -): +def test_two_point_xi_theta_wrong_shape(real_two_point_xy: TwoPointXY): + theta = np.array(np.linspace(0, 10), dtype=np.float64).reshape(-1, 10) + with pytest.raises( + ValueError, + match="Thetas should be a 1D array.", + ): + TwoPointXiTheta(XY=real_two_point_xy, thetas=theta) + + +def test_two_point_from_metadata_cells(harmonic_two_point_xy, wl_factory, nc_factory): ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - xy = TwoPointXY(x=harmonic_bin_1, y=harmonic_bin_2) - cells = TwoPointCells(ells=ells, XY=xy) - two_point = TwoPoint.from_metadata_cells([cells], wl_factory, nc_factory).pop() + cells = TwoPointCells(ells=ells, XY=harmonic_two_point_xy) + two_point = TwoPoint.from_metadata_harmonic([cells], wl_factory, nc_factory).pop() assert two_point is not None assert isinstance(two_point, TwoPoint) @@ -755,11 +668,51 @@ def test_two_point_from_metadata_cells( assert isinstance(two_point.source0, SourceGalaxy) assert isinstance(two_point.source1, SourceGalaxy) - assert_array_equal(two_point.source0.tracer_args.z, harmonic_bin_1.z) - 
assert_array_equal(two_point.source0.tracer_args.z, harmonic_bin_1.z) + assert_array_equal(two_point.source0.tracer_args.z, harmonic_two_point_xy.x.z) + assert_array_equal(two_point.source1.tracer_args.z, harmonic_two_point_xy.y.z) + + assert_array_equal(two_point.source0.tracer_args.dndz, harmonic_two_point_xy.x.dndz) + assert_array_equal(two_point.source1.tracer_args.dndz, harmonic_two_point_xy.y.dndz) + + +def test_two_point_from_metadata_cwindow(two_point_cwindow, wl_factory, nc_factory): + two_point = TwoPoint.from_metadata_harmonic( + [two_point_cwindow], wl_factory, nc_factory + ).pop() + + assert two_point is not None + assert isinstance(two_point, TwoPoint) + assert two_point.sacc_data_type == two_point_cwindow.get_sacc_name() + + assert isinstance(two_point.source0, SourceGalaxy) + assert isinstance(two_point.source1, SourceGalaxy) + + assert_array_equal(two_point.source0.tracer_args.z, two_point_cwindow.XY.x.z) + assert_array_equal(two_point.source1.tracer_args.z, two_point_cwindow.XY.y.z) - assert_array_equal(two_point.source0.tracer_args.dndz, harmonic_bin_1.dndz) - assert_array_equal(two_point.source1.tracer_args.dndz, harmonic_bin_2.dndz) + assert_array_equal(two_point.source0.tracer_args.dndz, two_point_cwindow.XY.x.dndz) + assert_array_equal(two_point.source1.tracer_args.dndz, two_point_cwindow.XY.y.dndz) + + +def test_two_point_from_metadata_xi_theta(real_two_point_xy, wl_factory, nc_factory): + theta = np.array(np.linspace(0, 100, 100)) + xi_theta = TwoPointXiTheta(XY=real_two_point_xy, thetas=theta) + if xi_theta.get_sacc_name() == "galaxy_shear_xi_tt": + return + two_point = TwoPoint.from_metadata_real([xi_theta], wl_factory, nc_factory).pop() + + assert two_point is not None + assert isinstance(two_point, TwoPoint) + assert two_point.sacc_data_type == xi_theta.get_sacc_name() + + assert isinstance(two_point.source0, SourceGalaxy) + assert isinstance(two_point.source1, SourceGalaxy) + + assert_array_equal(two_point.source0.tracer_args.z, real_two_point_xy.x.z) + assert_array_equal(two_point.source1.tracer_args.z, real_two_point_xy.y.z) + + assert_array_equal(two_point.source0.tracer_args.dndz, real_two_point_xy.x.dndz) + assert_array_equal(two_point.source1.tracer_args.dndz, real_two_point_xy.y.dndz) def test_two_point_from_metadata_cells_unsupported_type(wl_factory, nc_factory): @@ -768,18 +721,20 @@ def test_two_point_from_metadata_cells_unsupported_type(wl_factory, nc_factory): bin_name="bname1", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=CMB.CONVERGENCE, + measurements={CMB.CONVERGENCE}, ) y = InferredGalaxyZDist( bin_name="bname2", z=np.linspace(0, 1, 100), dndz=np.ones(100), - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, + ) + xy = TwoPointXY( + x=x, y=y, x_measurement=CMB.CONVERGENCE, y_measurement=Galaxies.COUNTS ) - xy = TwoPointXY(x=x, y=y) cells = TwoPointCells(ells=ells, XY=xy) with pytest.raises( ValueError, match="Measurement .* not supported!", ): - TwoPoint.from_metadata_cells([cells], wl_factory, nc_factory) + TwoPoint.from_metadata_harmonic([cells], wl_factory, nc_factory) diff --git a/tests/metadata/test_metadata_two_point_measurement.py b/tests/metadata/test_metadata_two_point_measurement.py new file mode 100644 index 00000000..b64b8d1f --- /dev/null +++ b/tests/metadata/test_metadata_two_point_measurement.py @@ -0,0 +1,217 @@ +""" +Tests for the module firecrown.metadata.two_point +""" + +import pytest +import numpy as np +from numpy.testing import assert_array_equal + +from firecrown.metadata.two_point_types 
import ( + type_to_sacc_string_harmonic as harmonic, + type_to_sacc_string_real as real, +) +from firecrown.metadata.two_point import ( + TwoPointCells, + TwoPointCWindow, + TwoPointXiTheta, + TwoPointXY, + TwoPointMeasurement, + Window, +) + + +def test_two_point_cells_with_data(harmonic_two_point_xy: TwoPointXY): + ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) + data = np.ones_like(ells) * 1.1 + indices = np.arange(100) + covariance_name = "cov" + measure = TwoPointMeasurement( + data=data, indices=indices, covariance_name=covariance_name + ) + + cells = TwoPointCells(ells=ells, XY=harmonic_two_point_xy, Cell=measure) + + assert cells.ells[0] == 0 + assert cells.ells[-1] == 100 + assert cells.Cell is not None + assert cells.has_data() + assert_array_equal(cells.Cell.data, data) + assert_array_equal(cells.Cell.indices, indices) + assert cells.Cell.covariance_name == covariance_name + + +def test_two_point_two_point_cwindow_with_data(harmonic_two_point_xy: TwoPointXY): + ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) + ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64) + weights = np.ones(400).reshape(-1, 4) + + window = Window( + ells=ells, + weights=weights, + ells_for_interpolation=ells_for_interpolation, + ) + + ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) + data = np.zeros(4) + 1.1 + indices = np.arange(4) + covariance_name = "cov" + measure = TwoPointMeasurement( + data=data, indices=indices, covariance_name=covariance_name + ) + + two_point = TwoPointCWindow(XY=harmonic_two_point_xy, window=window, Cell=measure) + + assert two_point.window == window + assert two_point.XY == harmonic_two_point_xy + assert two_point.get_sacc_name() == harmonic( + harmonic_two_point_xy.x_measurement, harmonic_two_point_xy.y_measurement + ) + assert two_point.has_data() + assert two_point.Cell is not None + assert_array_equal(two_point.Cell.data, data) + assert_array_equal(two_point.Cell.indices, indices) + assert two_point.Cell.covariance_name == covariance_name + + +def test_two_point_xi_theta_with_data(real_two_point_xy: TwoPointXY): + data = np.zeros(100) + 1.1 + indices = np.arange(100) + covariance_name = "cov" + measure = TwoPointMeasurement( + data=data, indices=indices, covariance_name=covariance_name + ) + thetas = np.linspace(0.0, 1.0, 100) + + xi_theta = TwoPointXiTheta(XY=real_two_point_xy, thetas=thetas, xis=measure) + + assert xi_theta.XY == real_two_point_xy + assert xi_theta.get_sacc_name() == real( + real_two_point_xy.x_measurement, real_two_point_xy.y_measurement + ) + assert xi_theta.has_data() + assert xi_theta.xis is not None + assert_array_equal(xi_theta.xis.data, data) + assert_array_equal(xi_theta.xis.indices, indices) + assert xi_theta.xis.covariance_name == covariance_name + + +def test_two_point_cells_with_invalid_data_size(harmonic_two_point_xy: TwoPointXY): + ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) + data = np.zeros(101) + 1.1 + indices = np.arange(101) + covariance_name = "cov" + measure = TwoPointMeasurement( + data=data, indices=indices, covariance_name=covariance_name + ) + + with pytest.raises( + ValueError, + match="Cell should have the same shape as ells.", + ): + TwoPointCells(ells=ells, XY=harmonic_two_point_xy, Cell=measure) + + +def test_two_point_cwindow_with_invalid_data_size(harmonic_two_point_xy: TwoPointXY): + ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) + ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64) + weights = np.ones(400).reshape(-1, 
4) + + window = Window( + ells=ells, + weights=weights, + ells_for_interpolation=ells_for_interpolation, + ) + ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) + data = np.zeros(5) + 1.1 + indices = np.arange(5) + covariance_name = "cov" + measure = TwoPointMeasurement( + data=data, indices=indices, covariance_name=covariance_name + ) + + with pytest.raises( + ValueError, + match=( + "Data should have the same number of elements as the number " + "of observations supported by the window function." + ), + ): + TwoPointCWindow(XY=harmonic_two_point_xy, window=window, Cell=measure) + + +def test_two_point_xi_theta_with_invalid_data_size(real_two_point_xy: TwoPointXY): + data = np.zeros(101) + 1.1 + indices = np.arange(101) + covariance_name = "cov" + measure = TwoPointMeasurement( + data=data, indices=indices, covariance_name=covariance_name + ) + thetas = np.linspace(0.0, 1.0, 100) + + with pytest.raises( + ValueError, + match="Xis should have the same shape as thetas.", + ): + TwoPointXiTheta(XY=real_two_point_xy, thetas=thetas, xis=measure) + + +def test_two_point_measurement_invalid_data(): + data = np.array([1, 2, 3, 4, 5]).reshape(-1, 1) + indices = np.array([1, 2, 3, 4, 5]) + covariance_name = "cov" + with pytest.raises( + ValueError, + match="Data should be a 1D array.", + ): + TwoPointMeasurement(data=data, indices=indices, covariance_name=covariance_name) + + +def test_two_point_measurement_invalid_indices(): + data = np.array([1, 2, 3, 4, 5]) + indices = np.array([1, 2, 3, 4, 5]).reshape(-1, 1) + covariance_name = "cov" + with pytest.raises( + ValueError, + match="Data and indices should have the same shape.", + ): + TwoPointMeasurement(data=data, indices=indices, covariance_name=covariance_name) + + +def test_two_point_measurement_eq(): + data = np.array([1, 2, 3, 4, 5]) + indices = np.array([1, 2, 3, 4, 5]) + covariance_name = "cov" + measure_1 = TwoPointMeasurement( + data=data, indices=indices, covariance_name=covariance_name + ) + measure_2 = TwoPointMeasurement( + data=data, indices=indices, covariance_name=covariance_name + ) + assert measure_1 == measure_2 + + +def test_two_point_measurement_neq(): + data = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + indices = np.array([1, 2, 3, 4, 5]) + covariance_name = "cov" + measure_1 = TwoPointMeasurement( + data=data, + indices=indices, + covariance_name=covariance_name, + ) + measure_2 = TwoPointMeasurement(data=data, indices=indices, covariance_name="cov2") + assert measure_1 != measure_2 + measure_3 = TwoPointMeasurement( + data=data, indices=indices, covariance_name=covariance_name + ) + measure_4 = TwoPointMeasurement( + data=data, indices=indices + 1, covariance_name=covariance_name + ) + assert measure_3 != measure_4 + measure_5 = TwoPointMeasurement( + data=data, indices=indices, covariance_name=covariance_name + ) + measure_6 = TwoPointMeasurement( + data=data + 1.0, indices=indices, covariance_name=covariance_name + ) + assert measure_5 != measure_6 diff --git a/tests/metadata/test_metadata_two_point_sacc.py b/tests/metadata/test_metadata_two_point_sacc.py index 828fab43..b220f4bc 100644 --- a/tests/metadata/test_metadata_two_point_sacc.py +++ b/tests/metadata/test_metadata_two_point_sacc.py @@ -10,17 +10,27 @@ import sacc +from firecrown.parameters import ParamsMap +from firecrown.metadata.two_point_types import ( + Galaxies, + type_to_sacc_string_harmonic, + type_to_sacc_string_real, +) from firecrown.metadata.two_point import ( extract_all_data_types_cells, extract_all_data_types_xi_thetas, 
extract_all_photoz_bin_combinations, extract_all_tracers, extract_window_function, - Galaxies, + extract_all_data_cells, + extract_all_data_xi_thetas, + check_two_point_consistence_harmonic, + check_two_point_consistence_real, TracerNames, TwoPointCells, TwoPointXiTheta, ) +from firecrown.likelihood.two_point import TwoPoint, use_source_factory @pytest.fixture(name="sacc_galaxy_src0_src0_invalid_data_type") @@ -233,7 +243,7 @@ def test_extract_all_tracers_cells_src0_src0(sacc_galaxy_cells_src0_src0): assert_array_equal(tracer.z, z) assert_array_equal(tracer.dndz, dndz) assert tracer.bin_name == "src0" - assert tracer.measurement == Galaxies.SHEAR_E + assert tracer.measurements == {Galaxies.SHEAR_E} def test_extract_all_tracers_cells_src0_src1(sacc_galaxy_cells_src0_src1): @@ -245,7 +255,7 @@ def test_extract_all_tracers_cells_src0_src1(sacc_galaxy_cells_src0_src1): for tracer in all_tracers: assert_array_equal(tracer.z, z) - assert tracer.measurement == Galaxies.SHEAR_E + assert tracer.measurements == {Galaxies.SHEAR_E} if tracer.bin_name == "src0": assert_array_equal(tracer.dndz, dndz0) elif tracer.bin_name == "src1": @@ -263,7 +273,7 @@ def test_extract_all_tracers_cells_lens0_lens0(sacc_galaxy_cells_lens0_lens0): assert_array_equal(tracer.z, z) assert_array_equal(tracer.dndz, dndz) assert tracer.bin_name == "lens0" - assert tracer.measurement == Galaxies.COUNTS + assert tracer.measurements == {Galaxies.COUNTS} def test_extract_all_tracers_cells_lens0_lens1(sacc_galaxy_cells_lens0_lens1): @@ -275,7 +285,7 @@ def test_extract_all_tracers_cells_lens0_lens1(sacc_galaxy_cells_lens0_lens1): for tracer in all_tracers: assert_array_equal(tracer.z, z) - assert tracer.measurement == Galaxies.COUNTS + assert tracer.measurements == {Galaxies.COUNTS} if tracer.bin_name == "lens0": assert_array_equal(tracer.dndz, dndz0) elif tracer.bin_name == "lens1": @@ -293,7 +303,7 @@ def test_extract_all_tracers_xis_lens0_lens0(sacc_galaxy_xis_lens0_lens0): assert_array_equal(tracer.z, z) assert_array_equal(tracer.dndz, dndz) assert tracer.bin_name == "lens0" - assert tracer.measurement == Galaxies.COUNTS + assert tracer.measurements == {Galaxies.COUNTS} def test_extract_all_tracers_xis_lens0_lens1(sacc_galaxy_xis_lens0_lens1): @@ -305,7 +315,7 @@ def test_extract_all_tracers_xis_lens0_lens1(sacc_galaxy_xis_lens0_lens1): for tracer in all_tracers: assert_array_equal(tracer.z, z) - assert tracer.measurement == Galaxies.COUNTS + assert tracer.measurements == {Galaxies.COUNTS} if tracer.bin_name == "lens0": assert_array_equal(tracer.dndz, dndz0) elif tracer.bin_name == "lens1": @@ -323,11 +333,11 @@ def test_extract_all_trace_cells_src0_lens0(sacc_galaxy_cells_src0_lens0): if tracer.bin_name == "src0": assert_array_equal(tracer.z, z) assert_array_equal(tracer.dndz, dndz0) - assert tracer.measurement == Galaxies.SHEAR_E + assert tracer.measurements == {Galaxies.SHEAR_E} elif tracer.bin_name == "lens0": assert_array_equal(tracer.z, z) assert_array_equal(tracer.dndz, dndz1) - assert tracer.measurement == Galaxies.COUNTS + assert tracer.measurements == {Galaxies.COUNTS} def test_extract_all_trace_xis_src0_lens0(sacc_galaxy_xis_src0_lens0): @@ -341,11 +351,11 @@ def test_extract_all_trace_xis_src0_lens0(sacc_galaxy_xis_src0_lens0): if tracer.bin_name == "src0": assert_array_equal(tracer.z, z) assert_array_equal(tracer.dndz, dndz0) - assert tracer.measurement == Galaxies.SHEAR_T + assert tracer.measurements == {Galaxies.SHEAR_T} elif tracer.bin_name == "lens0": assert_array_equal(tracer.z, z) 
assert_array_equal(tracer.dndz, dndz1) - assert tracer.measurement == Galaxies.COUNTS + assert tracer.measurements == {Galaxies.COUNTS} def test_extract_all_tracers_invalid_data_type( @@ -366,7 +376,7 @@ def test_extract_all_tracers_bad_lens_label( assert sacc_data is not None with pytest.raises( ValueError, - match="Tracer non_informative_label does not have a compatible Measurement.", + match="Tracer src0 does not have data points associated with it.", ): _ = extract_all_tracers(sacc_data) @@ -378,7 +388,9 @@ def test_extract_all_tracers_bad_source_label( assert sacc_data is not None with pytest.raises( ValueError, - match=("Tracer non_informative_label does not have a compatible Measurement."), + match=( + "Tracer non_informative_label does not have data points associated with it." + ), ): _ = extract_all_tracers(sacc_data) @@ -390,10 +402,7 @@ def test_extract_all_tracers_inconsistent_lens_label( assert sacc_data is not None with pytest.raises( ValueError, - match=( - "Tracer lens0 matches the lens regex but does " - "not have a compatible Measurement." - ), + match=("Invalid SACC file, tracer names do not respect the naming convetion."), ): _ = extract_all_tracers(sacc_data) @@ -405,10 +414,7 @@ def test_extract_all_tracers_inconsistent_source_label( assert sacc_data is not None with pytest.raises( ValueError, - match=( - "Tracer src0 matches the source regex but does " - "not have a compatible Measurement." - ), + match=("Invalid SACC file, tracer names do not respect the naming convetion."), ): _ = extract_all_tracers(sacc_data) @@ -421,31 +427,26 @@ def test_extract_all_data_cells(sacc_galaxy_cells): for two_point in all_data: tracer_names = two_point["tracer_names"] - assert tracer_names in tracer_pairs + assert (tracer_names, two_point["data_type"]) in tracer_pairs - tracer_pair = tracer_pairs[tracer_names] - assert_array_equal(two_point["ells"], tracer_pair[1]) - assert two_point["data_type"] == tracer_pair[0] + tracer_pair = tracer_pairs[(tracer_names, two_point["data_type"])] + assert_array_equal(two_point["ells"], tracer_pair[0]) def test_extract_all_data_cells_by_type(sacc_galaxy_cells): sacc_data, _, tracer_pairs = sacc_galaxy_cells - tracer_pairs = { - k: v for k, v in tracer_pairs.items() if v[0] == "galaxy_shear_cl_ee" - } all_data = extract_all_data_types_cells( sacc_data, allowed_data_type=["galaxy_shear_cl_ee"] ) - assert len(all_data) == len(tracer_pairs) + assert len(all_data) < len(tracer_pairs) for two_point in all_data: tracer_names = two_point["tracer_names"] - assert tracer_names in tracer_pairs + assert (tracer_names, "galaxy_shear_cl_ee") in tracer_pairs - tracer_pair = tracer_pairs[tracer_names] - assert_array_equal(two_point["ells"], tracer_pair[1]) - assert two_point["data_type"] == tracer_pair[0] + tracer_pair = tracer_pairs[(tracer_names, "galaxy_shear_cl_ee")] + assert_array_equal(two_point["ells"], tracer_pair[0]) def test_extract_all_data_xis(sacc_galaxy_xis): @@ -456,32 +457,26 @@ def test_extract_all_data_xis(sacc_galaxy_xis): for two_point in all_data: tracer_names = two_point["tracer_names"] - assert tracer_names in tracer_pairs + assert (tracer_names, two_point["data_type"]) in tracer_pairs - tracer_pair = tracer_pairs[tracer_names] - assert_array_equal(two_point["thetas"], tracer_pair[1]) - assert two_point["data_type"] == tracer_pair[0] + tracer_pair = tracer_pairs[(tracer_names, two_point["data_type"])] + assert_array_equal(two_point["thetas"], tracer_pair[0]) def test_extract_all_data_xis_by_type(sacc_galaxy_xis): sacc_data, _, tracer_pairs = 
sacc_galaxy_xis - tracer_pairs = { - k: v for k, v in tracer_pairs.items() if v[0] == "galaxy_density_xi" - } - all_data = extract_all_data_types_xi_thetas( sacc_data, allowed_data_type=["galaxy_density_xi"] ) - assert len(all_data) == len(tracer_pairs) + assert len(all_data) < len(tracer_pairs) for two_point in all_data: tracer_names = two_point["tracer_names"] - assert tracer_names in tracer_pairs + assert (tracer_names, two_point["data_type"]) in tracer_pairs - tracer_pair = tracer_pairs[tracer_names] - assert_array_equal(two_point["thetas"], tracer_pair[1]) - assert two_point["data_type"] == tracer_pair[0] + tracer_pair = tracer_pairs[(tracer_names, two_point["data_type"])] + assert_array_equal(two_point["thetas"], tracer_pair[0]) def test_extract_no_window(sacc_galaxy_cells_src0_src0): @@ -525,19 +520,22 @@ def test_extract_all_photoz_bin_combinations_xis(sacc_galaxy_xis): # We get all combinations already present in the data tracer_names_list = [ - TracerNames(bin_comb.x.bin_name, bin_comb.y.bin_name) + ( + TracerNames(bin_comb.x.bin_name, bin_comb.y.bin_name), + type_to_sacc_string_real(bin_comb.x_measurement, bin_comb.y_measurement), + ) for bin_comb in all_bin_combs ] # We check if the particular are present in the list - for tracer_names in tracer_pairs: - assert tracer_names in tracer_names_list + for tracer_names_type in tracer_pairs: + assert tracer_names_type in tracer_names_list - bin_comb = all_bin_combs[tracer_names_list.index(tracer_names)] + bin_comb = all_bin_combs[tracer_names_list.index(tracer_names_type)] two_point_xis = TwoPointXiTheta( XY=bin_comb, thetas=np.linspace(0.0, 2.0 * np.pi, 20) ) - assert two_point_xis.get_sacc_name() == tracer_pairs[tracer_names][0] + assert two_point_xis.get_sacc_name() == tracer_names_type[1] def test_extract_all_photoz_bin_combinations_cells(sacc_galaxy_cells): @@ -547,19 +545,24 @@ def test_extract_all_photoz_bin_combinations_cells(sacc_galaxy_cells): # We get all combinations already present in the data tracer_names_list = [ - TracerNames(bin_comb.x.bin_name, bin_comb.y.bin_name) + ( + TracerNames(bin_comb.x.bin_name, bin_comb.y.bin_name), + type_to_sacc_string_harmonic( + bin_comb.x_measurement, bin_comb.y_measurement + ), + ) for bin_comb in all_bin_combs ] # We check if the particular are present in the list - for tracer_names in tracer_pairs: - assert tracer_names in tracer_names_list + for tracer_names_type in tracer_pairs: + assert tracer_names_type in tracer_names_list - bin_comb = all_bin_combs[tracer_names_list.index(tracer_names)] + bin_comb = all_bin_combs[tracer_names_list.index(tracer_names_type)] two_point_cells = TwoPointCells( XY=bin_comb, ells=np.unique(np.logspace(1, 3, 10)) ) - assert two_point_cells.get_sacc_name() == tracer_pairs[tracer_names][0] + assert two_point_cells.get_sacc_name() == tracer_names_type[1] def test_make_cells(sacc_galaxy_cells): @@ -590,3 +593,303 @@ def test_make_xis(sacc_galaxy_xis): assert_array_equal(two_point_xis.thetas, thetas) assert two_point_xis.XY is not None assert two_point_xis.XY == bin_comb + + +def test_extract_all_two_point_cells(sacc_galaxy_cells): + sacc_data, _, tracer_pairs = sacc_galaxy_cells + + two_point_cells, two_point_cwindows = extract_all_data_cells(sacc_data) + assert len(two_point_cells) + len(two_point_cwindows) == len(tracer_pairs) + + assert len(two_point_cwindows) == 0 + + for two_point in two_point_cells: + tracer_names = TracerNames(two_point.XY.x.bin_name, two_point.XY.y.bin_name) + assert (tracer_names, two_point.get_sacc_name()) in tracer_pairs + + 
ells, Cell = tracer_pairs[(tracer_names, two_point.get_sacc_name())] + assert_array_equal(two_point.ells, ells) + assert two_point.Cell is not None + assert_array_equal(two_point.Cell.data, Cell) + + check_two_point_consistence_harmonic(two_point_cells) + + +def test_extract_all_two_point_cwindows(sacc_galaxy_cwindows): + sacc_data, _, tracer_pairs = sacc_galaxy_cwindows + + two_point_cells, two_point_cwindows = extract_all_data_cells(sacc_data) + assert len(two_point_cells) + len(two_point_cwindows) == len(tracer_pairs) + assert len(two_point_cells) == 0 + + for two_point in two_point_cwindows: + tracer_names = TracerNames(two_point.XY.x.bin_name, two_point.XY.y.bin_name) + assert (tracer_names, two_point.get_sacc_name()) in tracer_pairs + + ells, Cell, window = tracer_pairs[(tracer_names, two_point.get_sacc_name())] + assert_array_equal(two_point.window.ells, ells) + assert two_point.Cell is not None + assert_array_equal(two_point.Cell.data, Cell) + assert two_point.window is not None + + assert_array_equal( + two_point.window.weights, window.weight / window.weight.sum(axis=0) + ) + assert_array_equal(two_point.window.ells, window.values) + + check_two_point_consistence_harmonic(two_point_cwindows) + + +def test_extract_all_data_xi_thetas(sacc_galaxy_xis): + sacc_data, _, tracer_pairs = sacc_galaxy_xis + + two_point_xis = extract_all_data_xi_thetas(sacc_data) + assert len(two_point_xis) == len(tracer_pairs) + + for two_point in two_point_xis: + tracer_names = TracerNames(two_point.XY.x.bin_name, two_point.XY.y.bin_name) + assert (tracer_names, two_point.get_sacc_name()) in tracer_pairs + + thetas, xi = tracer_pairs[(tracer_names, two_point.get_sacc_name())] + assert_array_equal(two_point.thetas, thetas) + assert two_point.xis is not None + assert_array_equal(two_point.xis.data, xi) + + check_two_point_consistence_real(two_point_xis) + + +def test_constructor_cells(sacc_galaxy_cells, wl_factory, nc_factory): + sacc_data, _, tracer_pairs = sacc_galaxy_cells + + two_point_cells, _ = extract_all_data_cells(sacc_data) + + two_points_new = TwoPoint.from_metadata_harmonic( + two_point_cells, wl_factory, nc_factory, check_consistence=True + ) + + assert two_points_new is not None + + for two_point in two_points_new: + tracer_pairs_key = (two_point.sacc_tracers, two_point.sacc_data_type) + assert tracer_pairs_key in tracer_pairs + assert two_point.ells is not None + assert_array_equal(two_point.ells, tracer_pairs[tracer_pairs_key][0]) + assert_array_equal( + two_point.get_data_vector(), tracer_pairs[tracer_pairs_key][1] + ) + + +def test_constructor_cwindows(sacc_galaxy_cwindows, wl_factory, nc_factory): + sacc_data, _, tracer_pairs = sacc_galaxy_cwindows + + _, two_point_cwindows = extract_all_data_cells(sacc_data) + + two_points_new = TwoPoint.from_metadata_harmonic( + two_point_cwindows, wl_factory, nc_factory + ) + + assert two_points_new is not None + + for two_point in two_points_new: + tracer_pairs_key = (two_point.sacc_tracers, two_point.sacc_data_type) + _, Cell, window = tracer_pairs[tracer_pairs_key] + + assert tracer_pairs_key in tracer_pairs + assert two_point.ells is None + assert_array_equal(two_point.get_data_vector(), Cell) + assert two_point.window is not None + assert_array_equal( + two_point.window.weights, + window.weight / window.weight.sum(axis=0), + ) + assert_array_equal(two_point.window.ells, window.values) + + +def test_constructor_xis(sacc_galaxy_xis, wl_factory, nc_factory): + sacc_data, _, tracer_pairs = sacc_galaxy_xis + + two_point_xis = 
extract_all_data_xi_thetas(sacc_data) + + two_points_new = TwoPoint.from_metadata_real( + two_point_xis, wl_factory, nc_factory, check_consistence=True + ) + + assert two_points_new is not None + + for two_point in two_points_new: + tracer_pairs_key = (two_point.sacc_tracers, two_point.sacc_data_type) + assert tracer_pairs_key in tracer_pairs + assert two_point.thetas is not None + assert_array_equal(two_point.thetas, tracer_pairs[tracer_pairs_key][0]) + assert_array_equal( + two_point.get_data_vector(), tracer_pairs[tracer_pairs_key][1] + ) + + +def test_compare_constructors_cells( + sacc_galaxy_cells, wl_factory, nc_factory, tools_with_vanilla_cosmology +): + sacc_data, _, _ = sacc_galaxy_cells + + two_point_cells, _ = extract_all_data_cells(sacc_data) + + two_points_harmonic = TwoPoint.from_metadata_harmonic( + two_point_cells, wl_factory, nc_factory + ) + + params = ParamsMap(two_points_harmonic.required_parameters().get_default_values()) + + two_points_old = [] + for cell in two_point_cells: + sacc_data_type = cell.get_sacc_name() + source0 = use_source_factory( + cell.XY.x, + cell.XY.x_measurement, + wl_factory=wl_factory, + nc_factory=nc_factory, + ) + source1 = use_source_factory( + cell.XY.y, + cell.XY.y_measurement, + wl_factory=wl_factory, + nc_factory=nc_factory, + ) + two_point = TwoPoint(sacc_data_type, source0, source1) + two_point.read(sacc_data) + two_points_old.append(two_point) + + assert len(two_points_harmonic) == len(two_points_old) + + for two_point, two_point_old in zip(two_points_harmonic, two_points_old): + assert two_point.sacc_tracers == two_point_old.sacc_tracers + assert two_point.sacc_data_type == two_point_old.sacc_data_type + assert two_point.ells is not None + assert two_point_old.ells is not None + assert_array_equal(two_point.ells, two_point_old.ells) + assert_array_equal(two_point.get_data_vector(), two_point_old.get_data_vector()) + assert two_point.sacc_indices is not None + assert two_point_old.sacc_indices is not None + assert_array_equal(two_point.sacc_indices, two_point_old.sacc_indices) + + two_point.update(params) + two_point_old.update(params) + + assert_array_equal( + two_point.compute_theory_vector(tools_with_vanilla_cosmology), + two_point_old.compute_theory_vector(tools_with_vanilla_cosmology), + ) + + +def test_compare_constructors_cwindows( + sacc_galaxy_cwindows, wl_factory, nc_factory, tools_with_vanilla_cosmology +): + sacc_data, _, _ = sacc_galaxy_cwindows + + _, two_point_cwindows = extract_all_data_cells(sacc_data) + + two_points_harmonic = TwoPoint.from_metadata_harmonic( + two_point_cwindows, wl_factory, nc_factory + ) + + params = ParamsMap(two_points_harmonic.required_parameters().get_default_values()) + + two_points_old = [] + for cwindow in two_point_cwindows: + sacc_data_type = cwindow.get_sacc_name() + source0 = use_source_factory( + cwindow.XY.x, + cwindow.XY.x_measurement, + wl_factory=wl_factory, + nc_factory=nc_factory, + ) + source1 = use_source_factory( + cwindow.XY.y, + cwindow.XY.y_measurement, + wl_factory=wl_factory, + nc_factory=nc_factory, + ) + two_point = TwoPoint(sacc_data_type, source0, source1) + two_point.read(sacc_data) + two_points_old.append(two_point) + + assert len(two_points_harmonic) == len(two_points_old) + + for two_point, two_point_old in zip(two_points_harmonic, two_points_old): + assert two_point.sacc_tracers == two_point_old.sacc_tracers + assert two_point.sacc_data_type == two_point_old.sacc_data_type + assert two_point.ells is None + assert ( + two_point_old.ells is not None + ) # The old 
constructor always sets the ells + assert_array_equal(two_point.get_data_vector(), two_point_old.get_data_vector()) + assert two_point.window is not None + assert two_point_old.window is not None + assert_array_equal(two_point.window.weights, two_point_old.window.weights) + assert_array_equal(two_point.window.ells, two_point_old.window.ells) + assert two_point.sacc_indices is not None + assert two_point_old.sacc_indices is not None + assert_array_equal(two_point.sacc_indices, two_point_old.sacc_indices) + + two_point.update(params) + two_point_old.update(params) + + assert_array_equal( + two_point.compute_theory_vector(tools_with_vanilla_cosmology), + two_point_old.compute_theory_vector(tools_with_vanilla_cosmology), + ) + + +def test_compare_constructors_xis( + sacc_galaxy_xis, wl_factory, nc_factory, tools_with_vanilla_cosmology +): + sacc_data, _, _ = sacc_galaxy_xis + + two_point_xis = extract_all_data_xi_thetas(sacc_data) + + two_points_real = TwoPoint.from_metadata_real(two_point_xis, wl_factory, nc_factory) + + params = ParamsMap(two_points_real.required_parameters().get_default_values()) + + two_points_old = [] + for xi in two_point_xis: + sacc_data_type = xi.get_sacc_name() + source0 = use_source_factory( + xi.XY.x, xi.XY.x_measurement, wl_factory=wl_factory, nc_factory=nc_factory + ) + source1 = use_source_factory( + xi.XY.y, xi.XY.y_measurement, wl_factory=wl_factory, nc_factory=nc_factory + ) + two_point = TwoPoint(sacc_data_type, source0, source1) + two_point.read(sacc_data) + two_points_old.append(two_point) + + assert len(two_points_real) == len(two_points_old) + + for two_point, two_point_old in zip(two_points_real, two_points_old): + assert two_point.sacc_tracers == two_point_old.sacc_tracers + assert two_point.sacc_data_type == two_point_old.sacc_data_type + assert two_point.thetas is not None + assert two_point_old.thetas is not None + assert_array_equal(two_point.thetas, two_point_old.thetas) + assert_array_equal(two_point.get_data_vector(), two_point_old.get_data_vector()) + assert two_point.sacc_indices is not None + assert two_point_old.sacc_indices is not None + assert_array_equal(two_point.sacc_indices, two_point_old.sacc_indices) + + two_point.update(params) + two_point_old.update(params) + + assert_array_equal( + two_point.compute_theory_vector(tools_with_vanilla_cosmology), + two_point_old.compute_theory_vector(tools_with_vanilla_cosmology), + ) + + +def test_extract_all_data_cells_no_cov(sacc_galaxy_cells): + sacc_data, _, _ = sacc_galaxy_cells + sacc_data.covariance = None + with pytest.raises( + ValueError, + match=("The SACC object does not have a covariance matrix."), + ): + _ = extract_all_data_cells(sacc_data, include_maybe_types=True) diff --git a/tests/metadata/test_metadata_two_point_types.py b/tests/metadata/test_metadata_two_point_types.py new file mode 100644 index 00000000..ffce7814 --- /dev/null +++ b/tests/metadata/test_metadata_two_point_types.py @@ -0,0 +1,581 @@ +""" +Tests for the module firecrown.metadata.two_point +""" + +from dataclasses import replace +from itertools import product, chain +from unittest.mock import MagicMock +import re +import pytest +import numpy as np + + +import sacc +import sacc_name_mapping as snm +from firecrown.metadata.two_point_types import ( + Galaxies, + Clusters, + CMB, + TracerNames, + TwoPointMeasurement, + ALL_MEASUREMENTS, + compare_enums, + type_to_sacc_string_harmonic as harmonic, + type_to_sacc_string_real as real, + measurement_is_compatible as is_compatible, + measurement_is_compatible_real as 
is_compatible_real,
+    measurement_is_compatible_harmonic as is_compatible_harmonic,
+    measurement_supports_harmonic as supports_harmonic,
+    measurement_supports_real as supports_real,
+)
+
+from firecrown.metadata.two_point import (
+    extract_all_tracers_types,
+    measurements_from_index,
+    LENS_REGEX,
+    SOURCE_REGEX,
+    TwoPointXiThetaIndex,
+    match_name_type,
+    check_two_point_consistence_harmonic,
+    check_two_point_consistence_real,
+    extract_all_data_cells,
+    extract_all_data_xi_thetas,
+)
+
+
+def test_order_enums():
+    assert compare_enums(CMB.CONVERGENCE, Clusters.COUNTS) < 0
+    assert compare_enums(Clusters.COUNTS, CMB.CONVERGENCE) > 0
+
+    assert compare_enums(CMB.CONVERGENCE, Galaxies.COUNTS) < 0
+    assert compare_enums(Galaxies.COUNTS, CMB.CONVERGENCE) > 0
+
+    assert compare_enums(Galaxies.SHEAR_E, Galaxies.SHEAR_T) < 0
+    assert compare_enums(Galaxies.SHEAR_E, Galaxies.COUNTS) < 0
+    assert compare_enums(Galaxies.SHEAR_T, Galaxies.COUNTS) < 0
+
+    assert compare_enums(Galaxies.COUNTS, Galaxies.SHEAR_E) > 0
+
+    for enumerand in ALL_MEASUREMENTS:
+        assert compare_enums(enumerand, enumerand) == 0
+
+
+def test_compare_enums_wrong_type():
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Unknown measurement type encountered "
+            "(<enum 'Galaxies'>, <class 'int'>)."
+        ),
+    ):
+        compare_enums(Galaxies.COUNTS, 1)  # type: ignore
+
+
+def test_enumeration_equality_galaxy():
+    for e1, e2 in product(Galaxies, chain(CMB, Clusters)):
+        assert e1 != e2
+
+
+def test_enumeration_equality_cmb():
+    for e1, e2 in product(CMB, chain(Galaxies, Clusters)):
+        assert e1 != e2
+
+
+def test_enumeration_equality_cluster():
+    for e1, e2 in product(Clusters, chain(CMB, Galaxies)):
+        assert e1 != e2
+
+
+def test_exact_matches():
+    for sacc_name, space, (enum_1, enum_2) in snm.mappings:
+        if space == "ell":
+            assert harmonic(enum_1, enum_2) == sacc_name
+        elif space == "theta":
+            assert real(enum_1, enum_2) == sacc_name
+        else:
+            raise ValueError(f"Illegal 'space' value {space} in testing data")
+
+
+def test_translation_invariants():
+    for a, b in product(ALL_MEASUREMENTS, ALL_MEASUREMENTS):
+        assert isinstance(a, (Galaxies, CMB, Clusters))
+        assert isinstance(b, (Galaxies, CMB, Clusters))
+        if is_compatible_real(a, b):
+            assert real(a, b) == real(b, a)
+        if is_compatible_harmonic(a, b):
+            assert harmonic(a, b) == harmonic(b, a)
+        if (
+            supports_harmonic(a)
+            and supports_harmonic(b)
+            and supports_real(a)
+            and supports_real(b)
+        ):
+            assert harmonic(a, b) != real(a, b)
+
+
+def test_unsupported_type_galaxy():
+    unknown_type = MagicMock()
+    unknown_type.configure_mock(__eq__=MagicMock(return_value=False))
+
+    with pytest.raises(ValueError, match="Untranslated Galaxy Measurement encountered"):
+        Galaxies.sacc_measurement_name(unknown_type)
+
+    with pytest.raises(ValueError, match="Untranslated Galaxy Measurement encountered"):
+        Galaxies.polarization(unknown_type)
+
+
+def test_unsupported_type_cmb():
+    unknown_type = MagicMock()
+    unknown_type.configure_mock(__eq__=MagicMock(return_value=False))
+
+    with pytest.raises(ValueError, match="Untranslated CMBMeasurement encountered"):
+        CMB.sacc_measurement_name(unknown_type)
+
+    with pytest.raises(ValueError, match="Untranslated CMBMeasurement encountered"):
+        CMB.polarization(unknown_type)
+
+
+def test_unsupported_type_cluster():
+    unknown_type = MagicMock()
+    unknown_type.configure_mock(__eq__=MagicMock(return_value=False))
+
+    with pytest.raises(ValueError, match="Untranslated ClusterMeasurement encountered"):
+        Clusters.sacc_measurement_name(unknown_type)
+
+    with
pytest.raises(ValueError, match="Untranslated ClusterMeasurement encountered"): + Clusters.polarization(unknown_type) + + +def test_type_hashs(): + for e1, e2 in product(ALL_MEASUREMENTS, ALL_MEASUREMENTS): + if e1 == e2: + assert hash(e1) == hash(e2) + else: + assert hash(e1) != hash(e2) + + +def test_measurement_is_compatible(): + for a, b in product(ALL_MEASUREMENTS, ALL_MEASUREMENTS): + assert isinstance(a, (Galaxies, CMB, Clusters)) + assert isinstance(b, (Galaxies, CMB, Clusters)) + if is_compatible_real(a, b) or is_compatible_harmonic(a, b): + assert is_compatible(a, b) + else: + assert not is_compatible(a, b) + + +def test_extract_all_tracers_types_cells( + sacc_galaxy_cells: tuple[sacc.Sacc, dict, dict] +): + sacc_data, _, _ = sacc_galaxy_cells + + tracers = extract_all_tracers_types(sacc_data) + + for tracer, measurements in tracers.items(): + if LENS_REGEX.match(tracer): + for measurement in measurements: + assert measurement == Galaxies.COUNTS + if SOURCE_REGEX.match(tracer): + for measurement in measurements: + assert measurement == Galaxies.SHEAR_E + + +def test_extract_all_tracers_types_cwindows( + sacc_galaxy_cwindows: tuple[sacc.Sacc, dict, dict] +): + sacc_data, _, _ = sacc_galaxy_cwindows + + tracers = extract_all_tracers_types(sacc_data) + + for tracer, measurements in tracers.items(): + if LENS_REGEX.match(tracer): + for measurement in measurements: + assert measurement == Galaxies.COUNTS + if SOURCE_REGEX.match(tracer): + for measurement in measurements: + assert measurement == Galaxies.SHEAR_E + + +def test_extract_all_tracers_types_xi_thetas( + sacc_galaxy_xis: tuple[sacc.Sacc, dict, dict] +): + sacc_data, _, _ = sacc_galaxy_xis + + tracers = extract_all_tracers_types(sacc_data) + + for tracer, measurements in tracers.items(): + if LENS_REGEX.match(tracer): + for measurement in measurements: + assert measurement == Galaxies.COUNTS + if SOURCE_REGEX.match(tracer): + assert measurements == { + Galaxies.SHEAR_T, + Galaxies.SHEAR_MINUS, + Galaxies.SHEAR_PLUS, + } + + +def test_extract_all_tracers_types_xi_thetas_inverted( + sacc_galaxy_xis_inverted: tuple[sacc.Sacc, dict, dict] +): + sacc_data, _, _ = sacc_galaxy_xis_inverted + + tracers = extract_all_tracers_types(sacc_data) + + for tracer, measurements in tracers.items(): + if LENS_REGEX.match(tracer): + for measurement in measurements: + assert measurement == Galaxies.COUNTS + if SOURCE_REGEX.match(tracer): + assert measurements == { + Galaxies.SHEAR_T, + Galaxies.SHEAR_MINUS, + Galaxies.SHEAR_PLUS, + } + + +def test_extract_all_tracers_types_cells_include_maybe( + sacc_galaxy_cells: tuple[sacc.Sacc, dict, dict] +): + sacc_data, _, _ = sacc_galaxy_cells + + assert extract_all_tracers_types( + sacc_data, include_maybe_types=True + ) == extract_all_tracers_types(sacc_data) + + +def test_extract_all_tracers_types_cwindows_include_maybe( + sacc_galaxy_cwindows: tuple[sacc.Sacc, dict, dict] +): + sacc_data, _, _ = sacc_galaxy_cwindows + + assert extract_all_tracers_types( + sacc_data, include_maybe_types=True + ) == extract_all_tracers_types(sacc_data) + + +def test_extract_all_tracers_types_xi_thetas_include_maybe( + sacc_galaxy_xis: tuple[sacc.Sacc, dict, dict] +): + sacc_data, _, _ = sacc_galaxy_xis + + assert extract_all_tracers_types( + sacc_data, include_maybe_types=True + ) == extract_all_tracers_types(sacc_data) + + +def test_measurements_from_index1(): + index: TwoPointXiThetaIndex = { + "data_type": "galaxy_shearDensity_xi_t", + "tracer_names": TracerNames("src0", "lens0"), + "thetas": np.linspace(0.0, 1.0, 
100), + } + n1, a, n2, b = measurements_from_index(index) + assert n1 == "lens0" + assert a == Galaxies.COUNTS + assert n2 == "src0" + assert b == Galaxies.SHEAR_T + + +def test_measurements_from_index2(): + index: TwoPointXiThetaIndex = { + "data_type": "galaxy_shearDensity_xi_t", + "tracer_names": TracerNames("lens0", "src0"), + "thetas": np.linspace(0.0, 1.0, 100), + } + n1, a, n2, b = measurements_from_index(index) + assert n1 == "lens0" + assert a == Galaxies.COUNTS + assert n2 == "src0" + assert b == Galaxies.SHEAR_T + + +def test_match_name_type1(): + match, n1, a, n2, b = match_name_type( + "src0", "lens0", Galaxies.SHEAR_T, Galaxies.COUNTS + ) + assert ( + match + and n1 == "lens0" + and a == Galaxies.COUNTS + and n2 == "src0" + and b == Galaxies.SHEAR_T + ) + + +def test_match_name_type2(): + match, n1, a, n2, b = match_name_type( + "src0", "lens0", Galaxies.COUNTS, Galaxies.SHEAR_T + ) + assert ( + match + and n1 == "lens0" + and a == Galaxies.COUNTS + and n2 == "src0" + and b == Galaxies.SHEAR_T + ) + + +def test_match_name_type3(): + match, n1, a, n2, b = match_name_type( + "lens0", "src0", Galaxies.SHEAR_T, Galaxies.COUNTS + ) + assert ( + match + and n1 == "lens0" + and a == Galaxies.COUNTS + and n2 == "src0" + and b == Galaxies.SHEAR_T + ) + + +def test_match_name_type4(): + match, n1, a, n2, b = match_name_type( + "lens0", "src0", Galaxies.COUNTS, Galaxies.SHEAR_T + ) + assert ( + match + and n1 == "lens0" + and a == Galaxies.COUNTS + and n2 == "src0" + and b == Galaxies.SHEAR_T + ) + + +def test_match_name_type_convention1(): + match, n1, a, n2, b = match_name_type( + "lens0", "no_convention", Galaxies.COUNTS, Galaxies.COUNTS + ) + assert not match + assert n1 == "lens0" + assert a == Galaxies.COUNTS + assert n2 == "no_convention" + assert b == Galaxies.COUNTS + + +def test_match_name_type_convention2(): + match, n1, a, n2, b = match_name_type( + "no_convention", "lens0", Galaxies.COUNTS, Galaxies.COUNTS + ) + assert not match + assert n1 == "no_convention" + assert a == Galaxies.COUNTS + assert n2 == "lens0" + assert b == Galaxies.COUNTS + + +def test_match_name_type_convention3(): + match, n1, a, n2, b = match_name_type( + "no_convention", "here_too", Galaxies.COUNTS, Galaxies.SHEAR_T + ) + assert not match + assert n1 == "no_convention" + assert a == Galaxies.COUNTS + assert n2 == "here_too" + assert b == Galaxies.SHEAR_T + + +def test_match_name_type_require_convention_fail(): + with pytest.raises( + ValueError, + match="Invalid tracer names (.*) do not respect the naming convetion.", + ): + match_name_type( + "no_convention", + "here_too", + Galaxies.COUNTS, + Galaxies.SHEAR_T, + require_convetion=True, + ) + + +def test_match_name_type_require_convention_lens(): + match, n1, a, n2, b = match_name_type( + "lens0", + "lens0", + Galaxies.COUNTS, + Galaxies.COUNTS, + require_convetion=True, + ) + assert not match + assert n1 == "lens0" + assert a == Galaxies.COUNTS + assert n2 == "lens0" + assert b == Galaxies.COUNTS + + +def test_match_name_type_require_convention_source(): + match, n1, a, n2, b = match_name_type( + "src0", + "src0", + Galaxies.SHEAR_MINUS, + Galaxies.SHEAR_MINUS, + require_convetion=True, + ) + assert not match + assert n1 == "src0" + assert a == Galaxies.SHEAR_MINUS + assert n2 == "src0" + assert b == Galaxies.SHEAR_MINUS + + +def test_check_two_point_consistence_harmonic(two_point_cell): + Cell = TwoPointMeasurement( + data=np.zeros(100), indices=np.arange(100), covariance_name="cov" + ) + 
check_two_point_consistence_harmonic([replace(two_point_cell, Cell=Cell)]) + + +def test_check_two_point_consistence_harmonic_missing_cell(two_point_cell): + with pytest.raises( + ValueError, + match="The TwoPointCells \\(.*, .*\\)\\[.*\\] does not contain a data.", + ): + check_two_point_consistence_harmonic([two_point_cell]) + + +def test_check_two_point_consistence_real(two_point_xi_theta): + xis = TwoPointMeasurement( + data=np.zeros(100), indices=np.arange(100), covariance_name="cov" + ) + check_two_point_consistence_real([replace(two_point_xi_theta, xis=xis)]) + + +def test_check_two_point_consistence_real_missing_xis(two_point_xi_theta): + with pytest.raises( + ValueError, + match="The TwoPointXiTheta \\(.*, .*\\)\\[.*\\] does not contain a data.", + ): + check_two_point_consistence_real([two_point_xi_theta]) + + +def test_check_two_point_consistence_harmonic_mixing_cov(sacc_galaxy_cells): + sacc_data, _, _ = sacc_galaxy_cells + + two_point_cells, _ = extract_all_data_cells(sacc_data) + + assert two_point_cells[0].Cell is not None + two_point_cells[0] = replace( + two_point_cells[0], + Cell=replace(two_point_cells[0].Cell, covariance_name="wrong_cov_name"), + ) + + with pytest.raises( + ValueError, + match=( + "The TwoPointCells .* has a different covariance name .* " + "than the previous TwoPointCells wrong_cov_name." + ), + ): + check_two_point_consistence_harmonic(two_point_cells) + + +def test_check_two_point_consistence_real_mixing_cov(sacc_galaxy_xis): + sacc_data, _, _ = sacc_galaxy_xis + + two_point_xis = extract_all_data_xi_thetas(sacc_data) + assert two_point_xis[0].xis is not None + two_point_xis[0] = replace( + two_point_xis[0], + xis=replace(two_point_xis[0].xis, covariance_name="wrong_cov_name"), + ) + + with pytest.raises( + ValueError, + match=( + "The TwoPointXiTheta .* has a different covariance name .* than the " + "previous TwoPointXiTheta wrong_cov_name." 
+ ), + ): + check_two_point_consistence_real(two_point_xis) + + +def test_check_two_point_consistence_harmonic_non_unique_indices(sacc_galaxy_cells): + sacc_data, _, _ = sacc_galaxy_cells + + two_point_cells, _ = extract_all_data_cells(sacc_data) + + assert two_point_cells[0].Cell is not None + new_indices = two_point_cells[0].Cell.indices + new_indices[0] = 3 + two_point_cells[0] = replace( + two_point_cells[0], + Cell=replace(two_point_cells[0].Cell, indices=new_indices), + ) + + with pytest.raises( + ValueError, + match="The indices of the TwoPointCells .* are not unique.", + ): + check_two_point_consistence_harmonic(two_point_cells) + + +def test_check_two_point_consistence_real_non_unique_indices(sacc_galaxy_xis): + sacc_data, _, _ = sacc_galaxy_xis + + two_point_xis = extract_all_data_xi_thetas(sacc_data) + assert two_point_xis[0].xis is not None + new_indices = two_point_xis[0].xis.indices + new_indices[0] = 3 + two_point_xis[0] = replace( + two_point_xis[0], + xis=replace(two_point_xis[0].xis, indices=new_indices), + ) + + with pytest.raises( + ValueError, + match="The indices of the TwoPointXiTheta .* are not unique.", + ): + check_two_point_consistence_real(two_point_xis) + + +def test_check_two_point_consistence_harmonic_indices_overlap(sacc_galaxy_cells): + sacc_data, _, _ = sacc_galaxy_cells + + two_point_cells, _ = extract_all_data_cells(sacc_data) + + assert two_point_cells[1].Cell is not None + new_indices = two_point_cells[1].Cell.indices + new_indices[1] = 3 + two_point_cells[1] = replace( + two_point_cells[1], + Cell=replace(two_point_cells[1].Cell, indices=new_indices), + ) + + with pytest.raises( + ValueError, + match="The indices of the TwoPointCells .* overlap.", + ): + check_two_point_consistence_harmonic(two_point_cells) + + +def test_check_two_point_consistence_real_indices_overlap(sacc_galaxy_xis): + sacc_data, _, _ = sacc_galaxy_xis + + two_point_xis = extract_all_data_xi_thetas(sacc_data) + assert two_point_xis[1].xis is not None + new_indices = two_point_xis[1].xis.indices + new_indices[1] = 3 + two_point_xis[1] = replace( + two_point_xis[1], + xis=replace(two_point_xis[1].xis, indices=new_indices), + ) + + with pytest.raises( + ValueError, + match="The indices of the TwoPointXiTheta .* overlap.", + ): + check_two_point_consistence_real(two_point_xis) + + +def test_extract_all_data_cells_ambiguous(sacc_galaxy_cells_ambiguous): + with pytest.raises( + ValueError, + match=( + "Ambiguous measurements for tracers .*. Impossible to " + "determine which measurement is from which tracer." 
+ ), + ): + _ = extract_all_data_cells( + sacc_galaxy_cells_ambiguous, include_maybe_types=True + ) diff --git a/tests/test_utils.py b/tests/test_utils.py index c15ce94f..238e9476 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,6 +9,7 @@ upper_triangle_indices, save_to_sacc, compare_optional_arrays, + compare_optionals, base_model_from_yaml, ) @@ -74,3 +75,21 @@ def test_compare_optional_arrays_(): def test_base_model_from_yaml_wrong(): with pytest.raises(ValueError): _ = base_model_from_yaml(str, "wrong") + + +def test_compare_optionals(): + x = "test" + y = "test" + assert compare_optionals(x, y) + + z = "test2" + assert not compare_optionals(x, z) + + a = None + b = None + assert compare_optionals(a, b) + + q = np.array([1, 2, 3]) + assert not compare_optionals(q, a) + + assert not compare_optionals(a, q) diff --git a/tutorial/_quarto.yml b/tutorial/_quarto.yml index b7169d98..8c8d6ca4 100644 --- a/tutorial/_quarto.yml +++ b/tutorial/_quarto.yml @@ -10,6 +10,7 @@ project: - inferred_zdist_generators.qmd - inferred_zdist_serialization.qmd - two_point_generators.qmd + - two_point_factories.qmd website: title: "Firecrown Tutorial Documents" @@ -37,6 +38,8 @@ website: text: "Redshift Distribution Serialization" - href: two_point_generators.qmd text: "Two-Point Function Generators" + - href: two_point_factories.qmd + text: "Two-Point Function Factories" - section: "Firecrown on GitHub" contents: - text: "Source Code" diff --git a/tutorial/inferred_zdist.qmd b/tutorial/inferred_zdist.qmd index a1a24511..35d001a5 100644 --- a/tutorial/inferred_zdist.qmd +++ b/tutorial/inferred_zdist.qmd @@ -25,7 +25,7 @@ The `InferredGalaxyZDist` object serves as the cornerstone for representing the The `InferredGalaxyZDist` object can be created by providing the following parameters: ```{python} -from firecrown.metadata.two_point import Galaxies, InferredGalaxyZDist +from firecrown.metadata.two_point_types import Galaxies, InferredGalaxyZDist import numpy as np z = np.linspace(0.0, 1.0, 200) @@ -33,7 +33,7 @@ lens0 = InferredGalaxyZDist( bin_name="lens0", z=z, dndz=np.exp(-0.5 * ((z - 0.5) / 0.02) ** 2) / np.sqrt(2 * np.pi) / 0.02, - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, ) ``` @@ -54,15 +54,14 @@ from firecrown.generators.two_point import LogLinearElls lens0_lens0 = TwoPointXY( x=lens0, y=lens0, + x_measurement=Galaxies.COUNTS, + y_measurement=Galaxies.COUNTS, ) # Note that we are leaving out the monopole (l=0) and dipole (l=1) # terms. Firecrown allows their inclusion, if you so desire. ells_generator = LogLinearElls(minimum=2, midpoint=20, maximum=200, n_log=20) -lens0_lens0_cell = TwoPointCells( - XY=lens0_lens0, - ells=ells_generator.generate() -) +lens0_lens0_cell = TwoPointCells(XY=lens0_lens0, ells=ells_generator.generate()) ``` The `TwoPointCells` is created by providing the `TwoPointXY` object and the ells array. 
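
As a quick check (a minimal sketch, using only the `get_sacc_name` method and the `ells` attribute that `TwoPointCells` carries), we can inspect the object just built:
```{python}
# Inspect the metadata object created above: the SACC data-type name
# implied by the two COUNTS measurements, and the generated ells.
print(lens0_lens0_cell.get_sacc_name())
print(lens0_lens0_cell.ells)
```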
@@ -78,8 +77,8 @@ Using the `lens0_lens0_cell` object created above, we can create a `TwoPoint` ob from firecrown.likelihood.two_point import TwoPoint, NumberCountsFactory # Generate a list of TwoPoint objects, one for each bin provided -all_two_points = TwoPoint.from_metadata_cells( - [lens0_lens0_cell], # Note we are passing a list with 1 bin; you could pass several +all_two_points = TwoPoint.from_metadata_harmonic( + [lens0_lens0_cell], # Note we are passing a list with 1 bin; you could pass several nc_factory=NumberCountsFactory(global_systematics=[], per_bin_systematics=[]), ) two_point_lens0_lens0 = all_two_points[0] diff --git a/tutorial/inferred_zdist_generators.qmd b/tutorial/inferred_zdist_generators.qmd index e9c1cb10..92bc1415 100644 --- a/tutorial/inferred_zdist_generators.qmd +++ b/tutorial/inferred_zdist_generators.qmd @@ -118,7 +118,7 @@ from firecrown.generators.inferred_galaxy_zdist import ( Y1_SOURCE_BINS, Y10_SOURCE_BINS, ) -from firecrown.metadata.two_point import Measurement, Galaxies +from firecrown.metadata.two_point_types import Measurement, Galaxies # These are the values at which we will evaluate the distribution. z = np.linspace(0.0, 0.6, 100) @@ -131,7 +131,7 @@ Pz_lens0_y1 = zdist_y1.binned_distribution( sigma_z=Y1_LENS_BINS["sigma_z"], z=z, name="lens0_y1", - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, ) Pz_source0_y1 = zdist_y1.binned_distribution( zpl=Y1_SOURCE_BINS["edges"][0], @@ -139,7 +139,7 @@ Pz_source0_y1 = zdist_y1.binned_distribution( sigma_z=Y1_SOURCE_BINS["sigma_z"], z=z, name="source0_y1", - measurement=Galaxies.SHEAR_E, + measurements={Galaxies.SHEAR_E}, ) Pz_lens0_y10 = zdist_y10.binned_distribution( zpl=Y10_LENS_BINS["edges"][0], @@ -147,7 +147,7 @@ Pz_lens0_y10 = zdist_y10.binned_distribution( sigma_z=Y10_LENS_BINS["sigma_z"], z=z, name="lens0_y10", - measurement=Galaxies.COUNTS, + measurements={Galaxies.COUNTS}, ) Pz_source0_y10 = zdist_y10.binned_distribution( zpl=Y10_SOURCE_BINS["edges"][0], @@ -155,7 +155,7 @@ Pz_source0_y10 = zdist_y10.binned_distribution( sigma_z=Y10_SOURCE_BINS["sigma_z"], z=z, name="source0_y10", - measurement=Galaxies.SHEAR_E, + measurements={Galaxies.SHEAR_E}, ) # Next we check that the objects we created are of the expected type. 
@@ -240,7 +240,7 @@ all_y1_bins = [
         sigma_z=Y1_LENS_BINS["sigma_z"],
         z=z,
         name=f"lens_{zpl:.1f}_{zpu:.1f}_y1",
-        measurement=Galaxies.COUNTS,
+        measurements={Galaxies.COUNTS},
         use_autoknot=True,
         autoknots_reltol=1.0e-5,
     )
@@ -252,7 +252,7 @@ all_y1_bins = [
         sigma_z=Y1_SOURCE_BINS["sigma_z"],
         z=z,
         name=f"source_{zpl:.1f}_{zpu:.1f}_y1",
-        measurement=Galaxies.SHEAR_E,
+        measurements={Galaxies.SHEAR_E},
         use_autoknot=True,
         autoknots_reltol=1.0e-5,
     )
@@ -276,7 +276,7 @@ d_y1 = pd.concat(
            "z": Pz.z,
            "dndz": Pz.dndz,
            "bin": Pz.bin_name,
-            "measurement": Pz.measurement,
+            "measurement": list(Pz.measurements)[0],
            "legend": f"{Pz.bin_name}, {len(Pz.z)}",
        }
    )
diff --git a/tutorial/inferred_zdist_serialization.qmd b/tutorial/inferred_zdist_serialization.qmd
index 4447b4ad..96d4c16b 100644
--- a/tutorial/inferred_zdist_serialization.qmd
+++ b/tutorial/inferred_zdist_serialization.qmd
@@ -27,7 +27,7 @@ The code snippet below demonstrates how to initialize and serialize this datacla
 ```{python}
 from firecrown.generators.inferred_galaxy_zdist import ZDistLSSTSRDBin, LinearGrid1D
-from firecrown.metadata.two_point import Galaxies
+from firecrown.metadata.two_point_types import Galaxies
 from firecrown.utils import base_model_to_yaml
 
 z = LinearGrid1D(start=0.01, end=0.5, num=20)
@@ -38,7 +38,7 @@ bin0 = ZDistLSSTSRDBin(
     sigma_z=0.03,
     z=z,
     bin_name="bin0",
-    measurement=Galaxies.COUNTS,
+    measurements={Galaxies.COUNTS},
 )
 
 bin0_yaml = base_model_to_yaml(bin0)
@@ -112,7 +112,7 @@ assert bin_collection.bins[0].zpl == bin_collection_read.bins[0].zpl
 assert bin_collection.bins[0].zpu == bin_collection_read.bins[0].zpu
 assert bin_collection.bins[0].sigma_z == bin_collection_read.bins[0].sigma_z
 assert bin_collection.bins[0].bin_name == bin_collection_read.bins[0].bin_name
-assert bin_collection.bins[0].measurement == bin_collection_read.bins[0].measurement
+assert bin_collection.bins[0].measurements == bin_collection_read.bins[0].measurements
 assert bin_collection.bins[0].z == bin_collection_read.bins[0].z
 ```
diff --git a/tutorial/two_point_factories.qmd b/tutorial/two_point_factories.qmd
new file mode 100644
index 00000000..a0bb2faf
--- /dev/null
+++ b/tutorial/two_point_factories.qmd
@@ -0,0 +1,206 @@
+---
+title: "Using Firecrown Factories to Initialize Two-Point Objects"
+format: html
+---
+
+{{< include _functions.qmd >}}
+
+## Purpose of this Document
+
+In the tutorial [Two Point Generators](two_point_generators.qmd) we show how to use generators to create the necessary metadata to instantiate `TwoPoint` objects.
+In addition to generators, the user also has the option of using data already saved in a `SACC` object (or file) to initialize the `TwoPoint` objects.
+In this tutorial we show how to extract the metadata and data from a `SACC` object, and how to use the metadata together with a factory to initialize the objects.
+
+## `SACC` objects
+
+A `SACC` object is used to store all the information necessary for a statistical analysis.
+In practice, it stores metadata (data layout, data types, bins, dependent variables, etc.), calibration data (the redshift distribution $\mathrm{d}n/\mathrm{d}z$ of each bin), and the covariance between all measurements.
+In the tutorials [`InferredGalaxyZDist`](inferred_zdist.qmd), [`InferredGalaxyZDist` Generators](inferred_zdist_generators.qmd) and [`InferredGalaxyZDist` Serialization](inferred_zdist_serialization.qmd) we show how to use, generate, and serialize the `InferredGalaxyZDist` objects describing the redshift distributions. In this document we extract these components from a `SACC` object.
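+
+Before extracting anything, it can be useful to peek at what the `SACC` object holds. The sketch below is illustrative only; it relies on standard `sacc` accessors (`tracers`, `get_data_types`, and `get_tracer_combinations`) and on the same file used in the next section.
+```{python}
+import sacc
+
+sacc_data = sacc.Sacc.load_fits("../examples/des_y1_3x2pt/des_y1_3x2pt_sacc_data.fits")
+
+# Tracers carry the calibration data: one redshift distribution per bin.
+print(list(sacc_data.tracers.keys()))
+
+# Data types describe which kinds of two-point measurements are present.
+for data_type in sacc_data.get_data_types():
+    print(data_type, len(sacc_data.get_tracer_combinations(data_type)))
+```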
+
+## Metadata-only `SACC`
+
+Up to version 1.7, Firecrown relied on a two-phase construction for its likelihoods.
+During the first phase, the `Likelihood`, `Statistics` and `Systematics` objects are created with metadata only.
+Then, the user must call the `read` method, passing a `SACC` object, to finish the object construction.
+
+One shortcoming of this methodology is that the user must know the structure of the `SACC` object beforehand in order to create the necessary objects.
+Then, a matching `SACC` object must be passed in the `read` phase.
+To simplify this, we introduced functions to extract the metadata from the `SACC` file, such that it can be used together with the factories to create the objects.
+
+In the code below we extract the metadata from the `examples/des_y1_3x2pt` `SACC` file.
+```{python}
+from firecrown.metadata.two_point import extract_all_data_types_xi_thetas
+import sacc
+
+sacc_data = sacc.Sacc.load_fits("../examples/des_y1_3x2pt/des_y1_3x2pt_sacc_data.fits")
+
+all_meta = extract_all_data_types_xi_thetas(sacc_data)
+```
+
+The metadata can be seen below:
+```{python}
+# | code-fold: true
+import yaml
+from IPython.display import Markdown
+
+all_meta_prune = [
+    {
+        "data_type": meta["data_type"],
+        "tracer1": str(meta["tracer_names"].name1),
+        "tracer2": str(meta["tracer_names"].name2),
+    }
+    for meta in all_meta
+]
+
+all_meta_yaml = yaml.safe_dump(all_meta_prune, default_flow_style=False)
+
+Markdown(f"```yaml\n{all_meta_yaml}\n```")
+```
+
+The tracer names above are the ones from the `SACC` object.
+The user needs to know the tracer names to create the `TwoPoint` objects.
+However, the user can now use the metadata to create the objects.
+
+### Factories
+
+The factories are used to create the `TwoPoint` objects.
+In Firecrown, we write factories using Pydantic, a library for validating and parsing data.
+This gives us the advantage of having a schema for the data, and the user can use the factories to validate the data before creating the objects.
+
+For example, we can organize the factories in a YAML file:
+```{python}
+
+from firecrown.likelihood.two_point import TwoPoint
+from firecrown.likelihood.weak_lensing import WeakLensingFactory
+from firecrown.likelihood.number_counts import NumberCountsFactory
+from firecrown.utils import base_model_from_yaml
+
+weak_lensing_yaml = """
+per_bin_systematics:
+- type: MultiplicativeShearBiasFactory
+- type: PhotoZShiftFactory
+global_systematics:
+- type: LinearAlignmentSystematicFactory
+  alphag: 1.0
+"""
+number_counts_yaml = """
+per_bin_systematics:
+- type: PhotoZShiftFactory
+global_systematics: []
+"""
+
+weak_lensing_factory = base_model_from_yaml(WeakLensingFactory, weak_lensing_yaml)
+number_counts_factory = base_model_from_yaml(NumberCountsFactory, number_counts_yaml)
+
+two_point_list = TwoPoint.from_metadata_only_real(
+    metadata=all_meta, wl_factory=weak_lensing_factory, nc_factory=number_counts_factory
+)
+```
+
+### Creating a `Likelihood` object
+
+Now we can create a `Likelihood` object using the `TwoPoint` objects.
+Since we are using the metadata only, we need to pass the `SACC` object to the `read` method.
+```{python}
+from firecrown.likelihood.gaussian import ConstGaussian
+
+likelihood = ConstGaussian(two_point_list)
+
+likelihood.read(sacc_data)
+
+```
+
+## Extracting the Metadata and Data
+
+In the previous section, we demonstrated how to extract metadata from a `SACC` object.
+Now, we will show how to extract all components from the `SACC` object, including the metadata, data, and calibration data.
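+
+Before doing so, note that the `all_meta` list extracted above is a plain Python list of dictionaries, so it can be filtered before the factories are invoked. The sketch below keeps a single data type; the `galaxy_density_xi` name is an assumption about what this file contains, so adjust it to one of the data types printed earlier:
+```{python}
+# Hypothetical filtering step: keep only galaxy-density correlation functions.
+density_meta = [
+    meta for meta in all_meta if meta["data_type"] == "galaxy_density_xi"
+]
+print(f"{len(density_meta)} of {len(all_meta)} metadata entries selected")
+```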
+
+In the upcoming interface for Firecrown, likelihood objects will be created using the complete data.
+This data may be extracted from a `SACC` object, generated by a custom generator, or provided in any other manner by the user.
+
+In the code below we extract all components from the `examples/des_y1_3x2pt` `SACC` file.
+```{python}
+from firecrown.metadata.two_point import extract_all_data_xi_thetas
+from pprint import pprint
+
+two_point_xi_thetas = extract_all_data_xi_thetas(sacc_data)
+```
+
+The list `two_point_xi_thetas` contains all information about the real-space two-point functions contained in the `SACC` object.
+Therefore, we need to use a different constructor to obtain the two-point objects already in the `ready` state.
+```{python}
+from firecrown.likelihood.two_point import TwoPoint
+
+weak_lensing_factory = base_model_from_yaml(WeakLensingFactory, weak_lensing_yaml)
+number_counts_factory = base_model_from_yaml(NumberCountsFactory, number_counts_yaml)
+
+two_points_ready = TwoPoint.from_metadata_real(
+    two_point_xi_thetas,
+    wl_factory=weak_lensing_factory,
+    nc_factory=number_counts_factory,
+    check_consistence=True,
+)
+
+```
+Note that we used the same factories to instantiate our `TwoPoint` objects.
+
+Now, to create the likelihood, we need to use a different constructor that supports creating likelihoods in the `ready` state.
+```{python}
+from firecrown.likelihood.gaussian import ConstGaussian
+
+likelihood_ready = ConstGaussian.create_ready(
+    two_points_ready, sacc_data.covariance.dense
+)
+```
+Note that, since we are creating a `Likelihood` in the `ready` state, the constructor requires the statistics' covariance.
+
+## Comparing results
+
+We can now compare whether the two methods yield the same results.
+Before we do that, we need to obtain our required parameters.
+```{python}
+from firecrown.parameters import ParamsMap
+
+req_params = likelihood.required_parameters()
+req_params_ready = likelihood_ready.required_parameters()
+
+assert req_params_ready == req_params
+
+default_values = req_params.get_default_values()
+params = ParamsMap(default_values)
+
+```
+Note that the required parameters are the same for both likelihoods.
+Before generating the two-point statistics, we use the following default values:
+```{python}
+# | code-fold: true
+import yaml
+
+default_values_yaml = yaml.dump(default_values, default_flow_style=False)
+
+Markdown(f"```yaml\n{default_values_yaml}\n```")
+```
+
+
+Finally, we can prepare both likelihoods and compare the loglike values.
+```{python} +import pyccl + +from firecrown.modeling_tools import ModelingTools + +ccl_cosmo = pyccl.CosmologyVanillaLCDM() +ccl_cosmo.compute_nonlin_power() + +tools = ModelingTools() +tools.update(params) +tools.prepare(ccl_cosmo) + +likelihood.update(params) +likelihood_ready.update(params) +``` +```{python} +# | code-fold: true +print(f"Loglike from metadata only: {likelihood.compute_loglike(tools)}") +print(f"Loglike from ready state: {likelihood_ready.compute_loglike(tools)}") +``` + diff --git a/tutorial/two_point_generators.qmd b/tutorial/two_point_generators.qmd index ff09ad44..bb5da5ac 100644 --- a/tutorial/two_point_generators.qmd +++ b/tutorial/two_point_generators.qmd @@ -138,7 +138,7 @@ These factories are responsible for producing the necessary systematics applied Below is the code defining the factories for weak lensing and number counts systematics: ```{python} -all_two_point_functions = tp.TwoPoint.from_metadata_cells( +all_two_point_functions = tp.TwoPoint.from_metadata_harmonic( metadata=all_two_point_cells, wl_factory=wlf, nc_factory=ncf, From bd7047635cecdae3f08f34833d5e93e51b6010b2 Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Tue, 13 Aug 2024 13:45:49 -0500 Subject: [PATCH 04/10] Yet even more docstrings (#445) * Improve docstrings and type annotations * Fix bug in NumberCounts property computation --- firecrown/likelihood/__init__.py | 2 +- .../binned_cluster_number_counts.py | 43 ++++-- firecrown/likelihood/gauss_family/__init__.py | 2 + .../gauss_family/statistic/source/__init__.py | 2 + firecrown/likelihood/gaussfamily.py | 85 ++++++++++-- firecrown/likelihood/gaussian.py | 13 +- firecrown/likelihood/likelihood.py | 90 +++++++++--- firecrown/likelihood/number_counts.py | 129 +++++++++++++----- firecrown/likelihood/supernova.py | 2 +- firecrown/likelihood/two_point.py | 5 +- .../gauss_family/statistic/test_statistic.py | 10 +- 11 files changed, 294 insertions(+), 89 deletions(-) diff --git a/firecrown/likelihood/__init__.py b/firecrown/likelihood/__init__.py index 636d36e7..0a6eede3 100644 --- a/firecrown/likelihood/__init__.py +++ b/firecrown/likelihood/__init__.py @@ -1,4 +1,4 @@ -"""Classes used to represent likelihoods and functions to support them. +"""Classes used to represent likelihoods, and functions to support them. Subpackages contain specific likelihood implementations, e.g., Gaussian and Student-t. The submodule :mod:`firecrown.likelihood.likelihood` contain the abstract base class for diff --git a/firecrown/likelihood/binned_cluster_number_counts.py b/firecrown/likelihood/binned_cluster_number_counts.py index 2732134b..00ab0398 100644 --- a/firecrown/likelihood/binned_cluster_number_counts.py +++ b/firecrown/likelihood/binned_cluster_number_counts.py @@ -1,8 +1,4 @@ -"""This module holds classes needed to predict the binned cluster number counts. - -The binned cluster number counts statistic predicts the number of galaxy -clusters within a single redshift and mass bin. -""" +"""Binned cluster number counts statistic support.""" from __future__ import annotations @@ -26,11 +22,7 @@ class BinnedClusterNumberCounts(Statistic): - """The Binned Cluster Number Counts statistic. - - This class will make a prediction for the number of clusters in a z, mass bin - and compare that prediction to the data provided in the sacc file. 
- """ + """A statistic representing the number of clusters in a z, mass bin.""" def __init__( self, @@ -39,6 +31,13 @@ def __init__( cluster_recipe: ClusterRecipe, systematics: None | list[SourceSystematic] = None, ): + """Initialize this statistic. + + :param cluster_properties: The cluster observables to use. + :param survey_name: The name of the survey to use. + :param cluster_recipe: The cluster recipe to use. + :param systematics: The systematics to apply to this statistic. + """ super().__init__() self.systematics = systematics or [] self.theory_vector: None | TheoryVector = None @@ -50,7 +49,10 @@ def __init__( self.bins: list[SaccBin] = [] def read(self, sacc_data: sacc.Sacc) -> None: - """Read the data for this statistic and mark it as ready for use.""" + """Read the data for this statistic and mark it as ready for use. + + :param sacc_data: The data in the sacc format. + """ # Build the data vector and indices needed for the likelihood if self.cluster_properties == ClusterProperty.NONE: raise ValueError("You must specify at least one cluster property.") @@ -77,12 +79,19 @@ def read(self, sacc_data: sacc.Sacc) -> None: super().read(sacc_data) def get_data_vector(self) -> DataVector: - """Gets the statistic data vector.""" + """Gets the statistic data vector. + + :return: The statistic data vector. + """ assert self.data_vector is not None return self.data_vector def _compute_theory_vector(self, tools: ModelingTools) -> TheoryVector: - """Compute a statistic from sources, concrete implementation.""" + """Compute a statistic from sources, concrete implementation. + + :param tools: The modeling tools used to compute the statistic. + :return: The computed statistic. + """ assert tools.cluster_abundance is not None theory_vector_list: list[float] = [] @@ -116,6 +125,9 @@ def get_binned_cluster_property( Using the data from the sacc file, this function evaluates the likelihood for a single point of the parameter space, and returns the predicted mean mass of the clusters in each bin. + + :param cluster_counts: The number of clusters in each bin. + :param cluster_properties: The cluster observables to use. """ assert tools.cluster_abundance is not None @@ -124,8 +136,6 @@ def get_binned_cluster_property( total_observable = self.cluster_recipe.evaluate_theory_prediction( tools.cluster_abundance, this_bin, self.sky_area, cluster_properties ) - cluster_counts.append(counts) - mean_observable = total_observable / counts mean_values.append(mean_observable) @@ -137,6 +147,9 @@ def get_binned_cluster_counts(self, tools: ModelingTools) -> list[float]: Using the data from the sacc file, this function evaluates the likelihood for a single point of the parameter space, and returns the predicted number of clusters in each bin. + + :param tools: The modeling tools used to compute the statistic. + :return: The number of clusters in each bin. 
""" assert tools.cluster_abundance is not None diff --git a/firecrown/likelihood/gauss_family/__init__.py b/firecrown/likelihood/gauss_family/__init__.py index 9c0fa90a..34815d38 100644 --- a/firecrown/likelihood/gauss_family/__init__.py +++ b/firecrown/likelihood/gauss_family/__init__.py @@ -1 +1,3 @@ +"""Backward compatibility support for deprecated directory structure.""" + # flake8: noqa diff --git a/firecrown/likelihood/gauss_family/statistic/source/__init__.py b/firecrown/likelihood/gauss_family/statistic/source/__init__.py index 9c0fa90a..34815d38 100644 --- a/firecrown/likelihood/gauss_family/statistic/source/__init__.py +++ b/firecrown/likelihood/gauss_family/statistic/source/__init__.py @@ -1 +1,3 @@ +"""Backward compatibility support for deprecated directory structure.""" + # flake8: noqa diff --git a/firecrown/likelihood/gaussfamily.py b/firecrown/likelihood/gaussfamily.py index 80ff5bbc..26516382 100644 --- a/firecrown/likelihood/gaussfamily.py +++ b/firecrown/likelihood/gaussfamily.py @@ -1,4 +1,4 @@ -"""Support for the family of Gaussian likelihood.""" +"""Support for the family of Gaussian likelihoods.""" from __future__ import annotations @@ -30,7 +30,12 @@ class State(Enum): - """The states used in GaussFamily.""" + """The states used in GaussFamily. + + GaussFamily and all subclasses enforce a statemachine behavior based on + these states to ensure that the necessary initialization and setup is done + in the correct order. + """ INITIALIZED = 1 READY = 2 @@ -62,6 +67,12 @@ def enforce_states( If terminal is None the state of the object is not modified. If terminal is not None and the call to the wrapped method returns normally the state of the object is set to terminal. + + :param initial: The initial states allowable for the wrapped method + :param terminal: The terminal state ensured for the wrapped method. None + indicates no state change happens. + :param failure_message: The failure message for the AssertionError raised + :return: The wrapped method """ initials: list[State] if isinstance(initial, list): @@ -74,6 +85,9 @@ def decorator_enforce_states(func: Callable[P, T]) -> Callable[P, T]: This closure is what actually contains the values of initials, terminal, and failure_message. + + :param func: The method to be wrapped + :return: The wrapped method """ @wraps(func) @@ -132,8 +146,11 @@ class GaussFamily(Likelihood): def __init__( self, statistics: Sequence[Statistic], - ): - """Initialize the base class parts of a GaussFamily object.""" + ) -> None: + """Initialize the base class parts of a GaussFamily object. + + :param statistics: A list of statistics to be include in chisquared calculations + """ super().__init__() self.state: State = State.INITIALIZED if len(statistics) == 0: @@ -160,7 +177,12 @@ def __init__( def create_ready( cls, statistics: Sequence[Statistic], covariance: npt.NDArray[np.float64] ) -> GaussFamily: - """Create a GaussFamily object in the READY state.""" + """Create a GaussFamily object in the READY state. + + :param statistics: A list of statistics to be include in chisquared calculations + :param covariance: The covariance matrix of the statistics + :return: A ready GaussFamily object + """ obj = cls(statistics) obj._set_covariance(covariance) obj.state = State.READY @@ -178,6 +200,8 @@ def _update(self, _: ParamsMap) -> None: for its own reasons must be sure to do what this does: check the state at the start of the method, and change the state at the end of the method. 
+
+        :param _: a ParamsMap object, not used
         """

     @enforce_states(
@@ -201,7 +225,10 @@ def _reset(self) -> None:
         failure_message="read() must only be called once",
     )
     def read(self, sacc_data: sacc.Sacc) -> None:
-        """Read the covariance matrix for this likelihood from the SACC file."""
+        """Read the covariance matrix for this likelihood from the SACC file.
+
+        :param sacc_data: The SACC data object to be read
+        """
         if sacc_data.covariance is None:
             msg = (
                 f"The {type(self).__name__} likelihood requires a covariance, "
@@ -216,11 +243,13 @@ def read(self, sacc_data: sacc.Sacc) -> None:

         self._set_covariance(covariance)

-    def _set_covariance(self, covariance):
+    def _set_covariance(self, covariance: npt.NDArray[np.float64]) -> None:
         """Set the covariance matrix.

         This method is used to set the covariance matrix and perform the necessary
         calculations to prepare the likelihood for computation.
+
+        :param covariance: The covariance matrix for this likelihood
         """
         indices_list = []
         data_vector_list = []
@@ -276,6 +305,7 @@ def get_cov(
         :param statistic: The statistic for which the sub-covariance matrix
             should be returned. If not specified, return the covariance of all
             statistics.
+        :return: The covariance matrix (or portion thereof)
         """
         assert self.cov is not None
         if statistic is None:
@@ -301,7 +331,10 @@ def get_cov(
         failure_message="read() must be called before get_data_vector()",
     )
     def get_data_vector(self) -> npt.NDArray[np.float64]:
-        """Get the data vector from all statistics in the right order."""
+        """Get the data vector from all statistics in the right order.
+
+        :return: The data vector
+        """
         assert self.data_vector is not None
         return self.data_vector

@@ -315,6 +348,7 @@ def compute_theory_vector(self, tools: ModelingTools) -> npt.NDArray[np.float64]
         """Computes the theory vector using the current instance of pyccl.Cosmology.

         :param tools: Current ModelingTools object
+        :return: The computed theory vector
         """
         theory_vector_list: list[npt.NDArray[np.float64]] = [
             stat.compute_theory_vector(tools) for stat in self.statistics
@@ -329,7 +363,10 @@ def compute_theory_vector(self, tools: ModelingTools) -> npt.NDArray[np.float64]
         "get_theory_vector()",
     )
     def get_theory_vector(self) -> npt.NDArray[np.float64]:
-        """Get the theory vector from all statistics in the right order."""
+        """Get the already-computed theory vector from all statistics.
+
+        :return: The theory vector, with all statistics in the right order
+        """
         assert (
             self.theory_vector is not None
         ), "theory_vector is None after compute_theory_vector() has been called"
@@ -343,7 +380,14 @@ def get_theory_vector(self) -> npt.NDArray[np.float64]:
     def compute(
         self, tools: ModelingTools
     ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
-        """Calculate and return both the data and theory vectors."""
+        """Calculate and return both the data and theory vectors.
+
+        This method is deprecated and will be removed in a future version of Firecrown.
+
+        :param tools: the ModelingTools to be used in the calculation of the
+            theory vector
+        :return: a tuple containing the data vector and the theory vector
+        """
         warnings.warn(
             "The use of the `compute` method on Statistic is deprecated."
             "The Statistic objects should implement `get_data` and "
@@ -359,7 +403,12 @@ def compute(
         failure_message="update() must be called before compute_chisq()",
     )
     def compute_chisq(self, tools: ModelingTools) -> float:
-        """Calculate and return the chi-squared for the given cosmology."""
+        """Calculate and return the chi-squared for the given cosmology.
+
+        :param tools: the ModelingTools to be used in the calculation of the
+            theory vector
+        :return: the chi-squared
+        """
         theory_vector: npt.NDArray[np.float64]
         data_vector: npt.NDArray[np.float64]
         residuals: npt.NDArray[np.float64]
@@ -386,6 +435,10 @@ def get_sacc_indices(
         """Get the SACC indices of the statistic or list of statistics.

         If no statistic is given, get the indices of all statistics of the likelihood.
+
+        :param statistics: The statistic or list of statistics for which the
+            SACC indices are desired
+        :return: The SACC indices
         """
         if statistic is None:
             statistic = [stat.statistic for stat in self.statistics]
@@ -409,7 +462,15 @@ def get_sacc_indices(
     def make_realization(
         self, sacc_data: sacc.Sacc, add_noise: bool = True, strict: bool = True
     ) -> sacc.Sacc:
-        """Create a new realization of the model."""
+        """Create a new realization of the model.
+
+        :param sacc_data: The SACC data object containing the covariance matrix
+            to be read
+        :param add_noise: If True, add noise to the realization.
+        :param strict: If True, check that the indices of the realization cover
+            all the indices of the SACC data object.
+        :return: The SACC data object containing the new realization
+        """
         sacc_indices = self.get_sacc_indices()

         if add_noise:
diff --git a/firecrown/likelihood/gaussian.py b/firecrown/likelihood/gaussian.py
index f503e6b6..dddd044a 100644
--- a/firecrown/likelihood/gaussian.py
+++ b/firecrown/likelihood/gaussian.py
@@ -13,12 +13,19 @@
 class ConstGaussian(GaussFamily):
     """A Gaussian log-likelihood with a constant covariance matrix."""

-    def compute_loglike(self, tools: ModelingTools):
-        """Compute the log-likelihood."""
+    def compute_loglike(self, tools: ModelingTools) -> float:
+        """Compute the log-likelihood.
+
+        :param tools: The modeling tools used to compute the likelihood.
+        :return: The log-likelihood.
+        """
         return -0.5 * self.compute_chisq(tools)

     def make_realization_vector(self) -> np.ndarray:
-        """Create a new realization of the model."""
+        """Create a new (randomized) realization of the model.
+
+        :return: A new realization of the model
+        """
         theory_vector = self.get_theory_vector()
         assert self.cholesky is not None
         new_data_vector = theory_vector + np.dot(
diff --git a/firecrown/likelihood/likelihood.py b/firecrown/likelihood/likelihood.py
index 9ace48cd..2aa75155 100644
--- a/firecrown/likelihood/likelihood.py
+++ b/firecrown/likelihood/likelihood.py
@@ -20,7 +20,7 @@
 a *data vector* and a *covariance matrix*, which must be present in the
 :class:`Sacc` object given to the :meth:`read` method.

-The theory predictions that are used in the calcluation of a likelihood are
+The theory predictions that are used in the calculation of a likelihood are
 expected to change for different calls to the :meth:`compute_loglike` method.
 In order to prepare a :class:`Likelihood` object for each call to
 :meth:`compute_loglike`, the following sequence of calls must be made (note
@@ -75,24 +75,32 @@ class Likelihood(Updatable):
     Concrete subclasses represent specific likelihood forms (e.g. gaussian with
     constant covariance matrix, or Student's t, etc.).

-    Concrete subclasses must have an implementation of both *read* and
-    *compute_loglike*. Note that abstract subclasses of Likelihood might implement
+    Concrete subclasses must have an implementation of both :meth:`read` and
+    :meth:`compute_loglike`. Note that abstract subclasses of Likelihood might implement
     these methods, and provide other abstract methods for their subclasses to implement.
""" def __init__(self, parameter_prefix: None | str = None) -> None: - """Default initialization for a base Likelihood object.""" + """Default initialization for a base Likelihood object. + + :params parameter_prefix: The prefix to prepend to all parameter names + """ super().__init__(parameter_prefix=parameter_prefix) @abstractmethod def read(self, sacc_data: sacc.Sacc) -> None: - """Read the covariance matrix for this likelihood from the SACC file.""" + """Read the covariance matrix for this likelihood from the SACC file. + + :param sacc_data: The SACC data object to be read + """ def make_realization_vector(self) -> npt.NDArray[np.float64]: """Create a new realization of the model. This new realization uses the previously computed theory vector and covariance matrix. + + :return: the new realization of the theory vector """ raise NotImplementedError( "This class does not implement make_realization_vector." @@ -111,11 +119,17 @@ def make_realization( only the theory vector. :param strict: If True, check that the indices of the realization cover all the indices of the SACC data object. + + :return: the new SACC object containing the new realization """ @abstractmethod def compute_loglike(self, tools: ModelingTools) -> float: - """Compute the log-likelihood of generic CCL data.""" + """Compute the log-likelihood of generic CCL data. + + :param tools: the ModelingTools to be used in calculating the likelihood + :return: the log-likelihood + """ class NamedParameters: @@ -142,14 +156,22 @@ def __init__( ] ) = None, ): - """Initialize the object from the supplied mapping of values.""" + """Initialize the object from the supplied mapping of values. + + :param mapping: the mapping from strings to values used for initialization + """ if mapping is None: self.data = {} else: self.data = dict(mapping) def get_bool(self, name: str, default_value: None | bool = None) -> bool: - """Return the named parameter as a bool.""" + """Return the named parameter as a bool. + + :param name: the name of the parameter to be returned + :param default_value: the default value if the parameter is not found + :return: the value of the parameter (or the default value) + """ if default_value is None: val = self.data[name] else: @@ -159,7 +181,12 @@ def get_bool(self, name: str, default_value: None | bool = None) -> bool: return val def get_string(self, name: str, default_value: None | str = None) -> str: - """Return the named parameter as a string.""" + """Return the named parameter as a string. + + :param name: the name of the parameter to be returned + :param default_value: the default value if the parameter is not found + :return: the value of the parameter (or the default value) + """ if default_value is None: val = self.data[name] else: @@ -169,7 +196,12 @@ def get_string(self, name: str, default_value: None | str = None) -> str: return val def get_int(self, name: str, default_value: None | int = None) -> int: - """Return the named parameter as an int.""" + """Return the named parameter as an int. + + :param name: the name of the parameter to be returned + :param default_value: the default value if the parameter is not found + :return: the value of the parameter (or the default value) + """ if default_value is None: val = self.data[name] else: @@ -179,7 +211,12 @@ def get_int(self, name: str, default_value: None | int = None) -> int: return val def get_float(self, name: str, default_value: None | float = None) -> float: - """Return the named parameter as a float.""" + """Return the named parameter as a float. 
+ + :param name: the name of the parameter to be returned + :param default_value: the default value if the parameter is not found + :return: the value of the parameter (or the default value) + """ if default_value is None: val = self.data[name] else: @@ -189,7 +226,11 @@ def get_float(self, name: str, default_value: None | float = None) -> float: return val def get_int_array(self, name: str) -> npt.NDArray[np.int64]: - """Return the named parameter as a numpy array of int.""" + """Return the named parameter as a numpy array of int. + + :param name: the name of the parameter to be returned + :return: the value of the parameter + """ tmp = self.data[name] assert isinstance(tmp, np.ndarray) val = tmp.view(dtype=np.int64) @@ -197,7 +238,11 @@ def get_int_array(self, name: str) -> npt.NDArray[np.int64]: return val def get_float_array(self, name: str) -> npt.NDArray[np.float64]: - """Return the named parameter as a numpy array of float.""" + """Return the named parameter as a numpy array of float. + + :param name: the name of the parameter to be returned + :return: the value of the parameter + """ tmp = self.data[name] assert isinstance(tmp, np.ndarray) val = tmp.view(dtype=np.float64) @@ -209,7 +254,10 @@ def to_set( ) -> set[ str | int | bool | float | npt.NDArray[np.int64] | npt.NDArray[np.float64] ]: - """Return the contained data as a set.""" + """Return the contained data as a set. + + :return: the value of the parameter as a set + """ return set(self.data) def set_from_basic_dict( @@ -219,7 +267,10 @@ def set_from_basic_dict( str | float | int | bool | Sequence[float] | Sequence[int] | Sequence[bool], ], ) -> None: - """Set the contained data from a dictionary of basic types.""" + """Set the contained data from a dictionary of basic types. + + :param basic_dict: the mapping from strings to values used for initialization + """ for key, value in basic_dict.items(): if isinstance(value, (str, float, int, bool)): self.data = dict(self.data, **{key: value}) @@ -245,7 +296,10 @@ def convert_to_basic_dict( str, str | float | int | bool | Sequence[float] | Sequence[int] | Sequence[bool], ]: - """Convert a NamedParameters object to a dictionary of basic types.""" + """Convert a NamedParameters object to a dictionary of built-in types. 
+
+        :return: a dictionary containing the parameters as built-in Python types
+        """
         basic_dict: dict[
             str,
             str | float | int | bool | Sequence[float] | Sequence[int] | Sequence[bool],
@@ -281,6 +335,7 @@ def load_likelihood_from_module_type(
     :param module: a loaded module
     :param build_parameters: a NamedParameters object containing the factory
         function parameters
+    :return: a tuple of the likelihood and the modeling tools
     """
     if not hasattr(module, "build_likelihood"):
         if not hasattr(module, "likelihood"):
@@ -336,6 +391,7 @@ def load_likelihood_from_script(
     :param filename: script filename
     :param build_parameters: a NamedParameters object containing the factory
         function parameters
+    :return: a tuple of the likelihood and the modeling tools
     """
     _, file_extension = os.path.splitext(filename)

@@ -384,6 +440,7 @@ def load_likelihood_from_module(
     :param module: module name
     :param build_parameters: a NamedParameters object containing the factory
         function parameters
+    :return: a tuple of the likelihood and the modeling tools
     """
     try:
         mod = importlib.import_module(module)
@@ -406,6 +463,7 @@ def load_likelihood(
     :param likelihood_name: script filename or module name
     :param build_parameters: a NamedParameters object containing the factory
         function parameters
+    :return: a tuple of the likelihood and the modeling tools
     """
     try:
         return load_likelihood_from_script(likelihood_name, build_parameters)
diff --git a/firecrown/likelihood/number_counts.py b/firecrown/likelihood/number_counts.py
index 48b3a3ae..a3cfa1c5 100644
--- a/firecrown/likelihood/number_counts.py
+++ b/firecrown/likelihood/number_counts.py
@@ -45,14 +45,23 @@ class NumberCountsArgs(SourceGalaxyArgs):
 class NumberCountsSystematic(SourceGalaxySystematic[NumberCountsArgs]):
     """Abstract base class for systematics for Number Counts sources.

-    Derived classes must implement :python`apply` with the correct signature.
+    Derived classes must implement :meth:`apply` with the correct signature.
     """

     @abstractmethod
     def apply(
         self, tools: ModelingTools, tracer_arg: NumberCountsArgs
     ) -> NumberCountsArgs:
-        """Apply method to include systematics in the tracer_arg."""
+        """Apply method to include systematics in the tracer_arg.
+
+        This does not modify the supplied tracer_arg; it returns a new
+        one that has been updated.
+
+        :param tools: the ModelingTools used to update the tracer_arg
+        :param tracer_arg: the original NumberCountsArgs to which to apply the
+            systematic
+        :return: the updated NumberCountsArgs
+        """


 class PhotoZShift(SourceGalaxyPhotoZShift[NumberCountsArgs]):
@@ -107,12 +116,10 @@ def apply(
     ) -> NumberCountsArgs:
         """Apply a linear bias systematic.

-        Parameters
-        ----------
-        cosmo : Cosmology
-            A Cosmology object.
-        tracer_arg : NumberCountsArgs
-            The source to which apply the shear bias.
+        :param tools: the ModelingTools used to update the tracer_arg
+        :param tracer_arg: a NumberCountsArgs object with values to be updated
+
+        :return: the updated NumberCountsArgs object
         """
         ccl_cosmo = tools.get_ccl_cosmology()
         pref = ((1.0 + tracer_arg.z) / (1.0 + self.z_piv)) ** self.alphaz
@@ -224,10 +231,10 @@ def apply(
     ) -> NumberCountsArgs:
         """Apply a magnification bias systematic.
- :param tools: a ModelingTools object - :param tracer_arg: a NumberCountsArgs object + :param tools: currently unused, by required by the interface + :param tracer_arg: a NumberCountsArgs object with values to be updated - :return: a NumberCountsArgs object + :return: an updated NumberCountsArgs object """ z_bar = self.z_c + self.z_m * (self.r_lim - 24.0) # The slope of log(n_tot(z,r_lim)) with respect to r_lim @@ -286,7 +293,7 @@ def apply( :param tools: currently unused, but required by interface :param tracer_arg: a NumberCountsArgs object with values to be updated - :return: the updated NumberCountsArgs object + :return: an updated NumberCountsArgs object """ return replace( tracer_arg, @@ -344,7 +351,19 @@ def create_ready( scale: float = 1.0, systematics: None | list[SourceGalaxySystematic[NumberCountsArgs]] = None, ) -> NumberCounts: - """Create a NumberCounts object with the given tracer name and scale.""" + """Create a NumberCounts object with the given tracer name and scale. + + This is the recommended way to create a NumberCounts object. It creates + a fully initialized object. + + :param inferred_zdist: the inferred redshift distribution + :param has_rsd: whether to include RSD in the tracer + :param derived_scale: whether to include a derived parameter for the scale + of the tracer + :param scale: the initial scale of the tracer + :param systematics: a list of systematics to apply to the tracer + :return: a fully initialized NumberCounts object + """ obj = cls( sacc_tracer=inferred_zdist.bin_name, systematics=systematics, @@ -365,15 +384,21 @@ def create_ready( return obj @final - def _update_source(self, params: ParamsMap): + def _update_source(self, params: ParamsMap) -> None: """Perform any updates necessary after the parameters have being updated. This implementation must update all contained Updatable instances. + + :param params: the parameters to be used for the update """ self.systematics.update(params) @final def _get_derived_parameters(self) -> DerivedParameterCollection: + """Return the derived parameters for this source. + + :return: the derived parameters + """ if self.derived_scale: assert self.current_tracer_args is not None derived_scale = DerivedParameter( @@ -387,13 +412,10 @@ def _get_derived_parameters(self) -> DerivedParameterCollection: return derived_parameters - def _read(self, sacc_data): + def _read(self, sacc_data) -> None: """Read the data for this source from the SACC file. - Parameters - ---------- - sacc_data : sacc.Sacc - The data in the sacc format. + :param sacc_data: The data in the sacc format. """ # pylint: disable=unexpected-keyword-arg self.tracer_args = NumberCountsArgs( @@ -409,7 +431,11 @@ def _read(self, sacc_data): def create_tracers( self, tools: ModelingTools ) -> tuple[list[Tracer], NumberCountsArgs]: - """Create the tracers for this source.""" + """Create the tracers for this source. + + :param tools: the ModelingTools used to create the tracers + :return: a tuple of tracers and the updated tracer_args + """ tracer_args = self.tracer_args tracer_args = replace(tracer_args, bias=self.bias * np.ones_like(tracer_args.z)) @@ -471,8 +497,11 @@ def create_tracers( return tracers, tracer_args - def get_scale(self): - """Return the scale for this source.""" + def get_scale(self) -> float: + """Return the scale for this source. + + :return: the scale for this source. 
+ """ assert self.current_tracer_args return self.current_tracer_args.scale @@ -488,11 +517,18 @@ class PhotoZShiftFactory(BaseModel): ] = "PhotoZShiftFactory" def create(self, bin_name: str) -> PhotoZShift: - """Create a PhotoZShift object with the given tracer name.""" + """Create a PhotoZShift object with the given tracer name for a bin. + + :param bin_name: the name of the bin + :return: the created PhotoZShift object + """ return PhotoZShift(bin_name) def create_global(self) -> PhotoZShift: - """Create a PhotoZShift object with the given tracer name.""" + """Required by the interface, but raises an error. + + PhotoZShift systematics cannot be global. + """ raise ValueError("PhotoZShift cannot be global.") @@ -526,11 +562,15 @@ class PTNonLinearBiasSystematicFactory(BaseModel): ] = "PTNonLinearBiasSystematicFactory" def create(self, bin_name: str) -> PTNonLinearBiasSystematic: - """Create a PTNonLinearBiasSystematic object with the given tracer name.""" + """Create a PTNonLinearBiasSystematic object with the given tracer name. + + :param bin_name: the name of the bin + :return: the created PTNonLinearBiasSystematic object + """ return PTNonLinearBiasSystematic(bin_name) def create_global(self) -> PTNonLinearBiasSystematic: - """Create a PTNonLinearBiasSystematic object with the given tracer name.""" + """Create a global PTNonLinearBiasSystematic object.""" return PTNonLinearBiasSystematic() @@ -545,11 +585,18 @@ class MagnificationBiasSystematicFactory(BaseModel): ] = "MagnificationBiasSystematicFactory" def create(self, bin_name: str) -> MagnificationBiasSystematic: - """Create a MagnificationBiasSystematic object with the given tracer name.""" + """Create a MagnificationBiasSystematic object with the given tracer name. + + :param bin_name: the name of the bin + :return: the created MagnificationBiasSystematic object + """ return MagnificationBiasSystematic(bin_name) def create_global(self) -> MagnificationBiasSystematic: - """Create a MagnificationBiasSystematic object with the given tracer name.""" + """Required by the interface, but raises an error. + + MagnificationBiasSystematic systematics cannot be global. + """ raise ValueError("MagnificationBiasSystematic cannot be global.") @@ -566,14 +613,15 @@ class ConstantMagnificationBiasSystematicFactory(BaseModel): def create(self, bin_name: str) -> ConstantMagnificationBiasSystematic: """Create a ConstantMagnificationBiasSystematic object. - Use the inferred_zdist to create the systematic. + :param bin_name: the name of the bin + :return: the created ConstantMagnificationBiasSystematic object """ return ConstantMagnificationBiasSystematic(bin_name) def create_global(self) -> ConstantMagnificationBiasSystematic: - """Create a ConstantMagnificationBiasSystematic object. + """Required by the interface, but raises an error. - Use the inferred_zdist to create the systematic. + ConstantMagnificationBiasSystematic systematics cannot be global. """ raise ValueError("ConstantMagnificationBiasSystematic cannot be global.") @@ -599,8 +647,11 @@ class NumberCountsFactory(BaseModel): per_bin_systematics: Sequence[NumberCountsSystematicFactory] global_systematics: Sequence[NumberCountsSystematicFactory] - def model_post_init(self, __context) -> None: - """Initialize the NumberCountsFactory.""" + def model_post_init(self, _) -> None: + """Initialize the NumberCountsFactory. 
+
+        :param _: required by the interface but not used
+        """
         self._cache: dict[int, NumberCounts] = {}
         self._global_systematics_instances = [
             nc_systematic_factory.create_global()
@@ -608,7 +659,11 @@ def model_post_init(self, __context) -> None:
         ]

     def create(self, inferred_zdist: InferredGalaxyZDist) -> NumberCounts:
-        """Create a NumberCounts object with the given tracer name and scale."""
+        """Create a NumberCounts object with the given tracer name and scale.
+
+        :param inferred_zdist: the inferred redshift distribution
+        :return: a fully initialized NumberCounts object
+        """
         inferred_zdist_id = id(inferred_zdist)
         if inferred_zdist_id in self._cache:
             return self._cache[inferred_zdist_id]
@@ -628,7 +683,11 @@ def create_from_metadata_only(
         self,
         sacc_tracer: str,
     ) -> NumberCounts:
-        """Create an WeakLensing object with the given tracer name and scale."""
+        """Create a NumberCounts object with the given tracer name and scale.
+
+        :param sacc_tracer: the name of the tracer
+        :return: a fully initialized NumberCounts object
+        """
         sacc_tracer_id = hash(sacc_tracer)  # Improve this
         if sacc_tracer_id in self._cache:
             return self._cache[sacc_tracer_id]
diff --git a/firecrown/likelihood/supernova.py b/firecrown/likelihood/supernova.py
index 40ee63e3..278d05f1 100644
--- a/firecrown/likelihood/supernova.py
+++ b/firecrown/likelihood/supernova.py
@@ -23,7 +23,7 @@


 class Supernova(Statistic):
-    """A supernova statistic.
+    """A statistic representing the distance modulus for a single supernova.

     This statistic that applies an additive shift M to a supernova's distance
     modulus.
diff --git a/firecrown/likelihood/two_point.py b/firecrown/likelihood/two_point.py
index 0f1becbd..bfc77a70 100644
--- a/firecrown/likelihood/two_point.py
+++ b/firecrown/likelihood/two_point.py
@@ -289,7 +289,10 @@ def use_source_factory_metadata_only(


 class TwoPoint(Statistic):
-    """A two-point statistic.
+    """A statistic that represents the correlation between two measurements.
+
+    If the same source is used twice in the same TwoPoint object, this produces
+    an autocorrelation.

     For example, shear correlation function, galaxy-shear correlation function,
     etc.
diff --git a/tests/likelihood/gauss_family/statistic/test_statistic.py b/tests/likelihood/gauss_family/statistic/test_statistic.py index 0738ab30..21741003 100644 --- a/tests/likelihood/gauss_family/statistic/test_statistic.py +++ b/tests/likelihood/gauss_family/statistic/test_statistic.py @@ -21,7 +21,7 @@ def test_vector_create(): assert vals.dtype == np.float64 assert vals.shape == (10,) for cls in VECTOR_CLASSES: - result = cls.create(vals) # type: ignore + result = cls.create(vals) assert isinstance(result, cls) assert result.shape == (10,) assert np.array_equal(vals, result) @@ -32,7 +32,7 @@ def test_vector_from_list(): assert isinstance(vals, list) assert len(vals) == 4 for cls in VECTOR_CLASSES: - result = cls.from_list(vals) # type: ignore + result = cls.from_list(vals) assert isinstance(result, cls) assert result.shape == (4,) for i, val in enumerate(vals): @@ -41,7 +41,7 @@ def test_vector_from_list(): def test_vector_slicing(): for cls in VECTOR_CLASSES: - vec = cls.create(np.random.random_sample((12,))) # type: ignore + vec = cls.create(np.random.random_sample((12,))) assert isinstance(vec, cls) middle_part = vec[3:6] assert middle_part.shape == (3,) @@ -50,7 +50,7 @@ def test_vector_slicing(): def test_vector_copying(): for cls in VECTOR_CLASSES: - vec = cls.create(np.random.random_sample((12,))) # type: ignore + vec = cls.create(np.random.random_sample((12,))) assert isinstance(vec, cls) vec_copy = vec.copy() assert vec_copy is not vec @@ -70,7 +70,7 @@ def test_ufunc_on_vector(): data = np.array([0.0, 0.25, 0.50]) expected = np.sin(data) for cls in VECTOR_CLASSES: - vec = cls.create(data) # type: ignore + vec = cls.create(data) result = np.sin(vec) assert isinstance(result, cls) assert np.array_equal(result, expected) From 5b2828441cec794c95f1ffa1717e6e584a9fb25d Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Tue, 20 Aug 2024 12:36:15 -0500 Subject: [PATCH 05/10] Again more docstrings (#447) * Update docstrings and type annotations --- firecrown/likelihood/number_counts.py | 2 +- firecrown/likelihood/source.py | 107 +++++++++++++++------ firecrown/likelihood/statistic.py | 133 +++++++++++++++++++++----- firecrown/likelihood/student_t.py | 13 ++- firecrown/likelihood/supernova.py | 22 ++++- 5 files changed, 214 insertions(+), 63 deletions(-) diff --git a/firecrown/likelihood/number_counts.py b/firecrown/likelihood/number_counts.py index a3cfa1c5..7fa73b47 100644 --- a/firecrown/likelihood/number_counts.py +++ b/firecrown/likelihood/number_counts.py @@ -231,7 +231,7 @@ def apply( ) -> NumberCountsArgs: """Apply a magnification bias systematic. - :param tools: currently unused, by required by the interface + :param tools: currently unused, but required by the interface :param tracer_arg: a NumberCountsArgs object with values to be updated :return: an updated NumberCountsArgs object diff --git a/firecrown/likelihood/source.py b/firecrown/likelihood/source.py index 8b6c491d..7ddafc28 100644 --- a/firecrown/likelihood/source.py +++ b/firecrown/likelihood/source.py @@ -29,19 +29,15 @@ class SourceSystematic(Updatable): the `apply` method of different subclasses are different. """ - def read(self, sacc_data: sacc.Sacc): - """Call to allow this object to read from the appropriate sacc data.""" + def read(self, sacc_data: sacc.Sacc) -> None: + """Call to allow this object to read from the appropriate sacc data. + :param sacc_data: The SACC data object to be read + """ -class Source(Updatable): - """An abstract source class (e.g., a sample of lenses). 
- Parameters - ---------- - systematics : list of str, optional - A list of the source-level systematics to apply to the source. The - default of `None` implies no systematics. - """ +class Source(Updatable): + """The abstract base class for all sources.""" systematics: Sequence[SourceSystematic] cosmo_hash: None | int @@ -57,22 +53,30 @@ def __init__(self, sacc_tracer: str) -> None: self.sacc_tracer = sacc_tracer @final - def read(self, sacc_data: sacc.Sacc): - """Read the data for this source from the SACC file.""" + def read(self, sacc_data: sacc.Sacc) -> None: + """Read the data for this source from the SACC file. + + :param sacc_data: The SACC data object to be read + """ if hasattr(self, "systematics"): for systematic in self.systematics: systematic.read(sacc_data) self._read(sacc_data) @abstractmethod - def _read(self, sacc_data: sacc.Sacc): - """Abstract method to read the data for this source from the SACC file.""" + def _read(self, sacc_data: sacc.Sacc) -> None: + """Abstract method to read the data for this source from the SACC file. + + :param sacc_data: The SACC data object to be read + """ - def _update_source(self, params: ParamsMap): + def _update_source(self, params: ParamsMap) -> None: """Method to update the source from the given ParamsMap. Any subclass that needs to do more than update its contained :class:`Updatable` objects should implement this method. + + :param params: the parameters to be used for the update """ @final @@ -81,6 +85,8 @@ def _update(self, params: ParamsMap): This clears the current hash and tracer, and calls the abstract method `_update_source`, which must be implemented in all subclasses. + + :param params: the parameters to be used for the update """ self.cosmo_hash = None self.tracers = [] @@ -88,11 +94,17 @@ def _update(self, params: ParamsMap): @abstractmethod def get_scale(self) -> float: - """Abstract method to return the scales for this `Source`.""" + """Abstract method to return the scale for this `Source`. + + :return: the scale + """ @abstractmethod def create_tracers(self, tools: ModelingTools): - """Create tracers for this `Source`, for the given cosmology.""" + """Abstract method to create tracers for this Source. + + :param tools: The modeling tools used for creating the tracers + """ @final def get_tracers(self, tools: ModelingTools) -> Sequence[Tracer]: @@ -100,6 +112,9 @@ def get_tracers(self, tools: ModelingTools) -> Sequence[Tracer]: This method caches its result, so if called a second time with the same cosmology, no calculation needs to be done. + + :param tools: The modeling tools used for creating the tracers + :return: the list of tracers """ ccl_cosmo = tools.get_ccl_cosmology() @@ -128,6 +143,10 @@ def determine_field_name(field: None | str, tracer: None | str) -> str: It is a static method only to keep it grouped with the class for which it is defining the initialization policy. + + :param field: the (stub) name of the field + :param tracer: the name of the tracer + :return: the full name of the field """ if field is not None: return field @@ -144,7 +163,7 @@ def __init__( halo_profile: None | pyccl.halos.HaloProfile = None, halo_2pt: None | pyccl.halos.Profile2pt = None, ): - """Initialize a new Tracer based on the pyccl.Tracer which must not be None. + """Initialize a new Tracer based on the provided tracer. Note that the :class:`pyccl.Tracer` is not copied; we store a reference to the original tracer. Be careful not to accidentally share :class:`pyccl.Tracer`s. 
@@ -154,6 +173,13 @@ def __init__( If no `field` is given, then the attribute :attr:`field` is set to either (1) the tracer_name, if one was given, or (2) 'delta_matter'. + + :param tracer: the pyccl.Tracer used as the basis for this Tracer. + :param tracer_name: optional name of the tracer. + :param field: optional name of the field associated with the tracer. + :param pt_tracer: optional non-linear perturbation theory tracer. + :param halo_profile: optional halo profile. + :param halo_2pt: optional halo profile 2-point object. """ assert tracer is not None self.ccl_tracer = tracer @@ -165,12 +191,18 @@ def __init__( @property def has_pt(self) -> bool: - """Return True if we have a pt_tracer, and False if not.""" + """Answer whether we have a perturbation theory tracer. + + :return: True if we have a pt_tracer, and False if not. + """ return self.pt_tracer is not None @property def has_hm(self) -> bool: - """Return True if we have a halo_profile, and False if not.""" + """Answer whether we have a halo profile. + + :return: True if we have a halo_profile, and False if not. + """ return self.halo_profile is not None @@ -193,13 +225,19 @@ class SourceGalaxyArgs: class SourceGalaxySystematic(SourceSystematic, Generic[_SourceGalaxyArgsT]): - """Abstract base class for all galaxy based source systematics.""" + """Abstract base class for all galaxy-based source systematics.""" @abstractmethod def apply( self, tools: ModelingTools, tracer_arg: _SourceGalaxyArgsT ) -> _SourceGalaxyArgsT: - """Apply method to include systematics in the tracer_arg.""" + """Apply method to include systematics in the tracer_arg. + + :param tools: the modeling tools use to update the tracer arg + :param tracer_arg: the original source galaxy tracer arg to which we + apply the systematic. + :return: a new source galaxy tracer arg with the systematic applied + """ _SourceGalaxySystematicT = TypeVar( @@ -236,8 +274,16 @@ def __init__(self, sacc_tracer: str) -> None: default_value=SOURCE_GALAXY_SYSTEMATIC_DEFAULT_DELTA_Z ) - def apply(self, tools: ModelingTools, tracer_arg: _SourceGalaxyArgsT): - """Apply a shift to the photo-z distribution of a source.""" + def apply( + self, tools: ModelingTools, tracer_arg: _SourceGalaxyArgsT + ) -> _SourceGalaxyArgsT: + """Apply a shift to the photo-z distribution of a source. + + :param tools: the modeling tools use to update the tracer arg + :param tracer_arg: the original source galaxy tracer arg to which we + apply the systematic. + :return: a new source galaxy tracer arg with the systematic applied + """ dndz_interp = Akima1DInterpolator(tracer_arg.z, tracer_arg.dndz) dndz = dndz_interp(tracer_arg.z - self.delta_z, extrapolate=False) @@ -263,7 +309,6 @@ def __init__(self, field: str = "delta_matter"): """Specify which 3D field should be used when computing angular power spectra. :param field: the name of the 3D field that is associated to the tracer. - Default: `"delta_matter"` """ super().__init__() self.field = field @@ -271,7 +316,13 @@ def __init__(self, field: str = "delta_matter"): def apply( self, tools: ModelingTools, tracer_arg: _SourceGalaxyArgsT ) -> _SourceGalaxyArgsT: - """Apply method to include systematics in the tracer_arg.""" + """Apply method to include systematics in the tracer_arg. + + :param tools: the modeling tools used to update the tracer_arg + :param tracer_arg: the original source galaxy tracer arg to which we + apply the systematics. 
+        :return: a new source galaxy tracer arg with the systematic applied
+        """
         return replace(tracer_arg, field=self.field)


@@ -299,11 +350,13 @@ def __init__(
         )
         self.tracer_args: _SourceGalaxyArgsT

-    def _read(self, sacc_data: sacc.Sacc):
+    def _read(self, sacc_data: sacc.Sacc) -> None:
         """Read the galaxy redshift distribution model from a sacc file.

         All derived classes must call this method in their own `_read` method
         after they have read their own data and initialized their tracer_args.
+
+        :param sacc_data: The SACC data object to be read
         """
         try:
             tracer_args = self.tracer_args
diff --git a/firecrown/likelihood/statistic.py b/firecrown/likelihood/statistic.py
index e951ff55..36e3b243 100644
--- a/firecrown/likelihood/statistic.py
+++ b/firecrown/likelihood/statistic.py
@@ -5,7 +5,7 @@
 """

 from __future__ import annotations
-from typing import final
+from typing import final, Iterator
 from dataclasses import dataclass
 from abc import abstractmethod
 import warnings
@@ -20,31 +20,47 @@


 class DataVector(npt.NDArray[np.float64]):
-    """This class wraps a np.ndarray that represents some observed data values."""
+    """Wrapper for a np.ndarray that represents some observed data values."""

     @classmethod
     def create(cls, vals: npt.NDArray[np.float64]) -> DataVector:
-        """Create a DataVector that wraps a copy of the given array vals."""
+        """Create a DataVector that wraps a copy of the given array vals.
+
+        :param vals: the array to be copied and wrapped
+        :return: a new DataVector
+        """
         return vals.view(cls)

     @classmethod
     def from_list(cls, vals: list[float]) -> DataVector:
-        """Create a DataVector from the given list of floats."""
+        """Create a DataVector from the given list of floats.
+
+        :param vals: the list of floats
+        :return: a new DataVector
+        """
         array = np.array(vals)
         return cls.create(array)


 class TheoryVector(npt.NDArray[np.float64]):
-    """This class represents an observation predicted by some theory."""
+    """Wrapper for an np.ndarray that represents a prediction by some theory."""

     @classmethod
     def create(cls, vals: npt.NDArray[np.float64]) -> TheoryVector:
-        """Create a TheoryVector that wraps a copy of the given array vals."""
+        """Create a TheoryVector that wraps a copy of the given array vals.
+
+        :param vals: the array to be copied and wrapped
+        :return: a new TheoryVector
+        """
         return vals.view(cls)

     @classmethod
     def from_list(cls, vals: list[float]) -> TheoryVector:
-        """Create a TheoryVector from the given list of floats."""
+        """Create a TheoryVector from the given list of floats.
+
+        :param vals: the list of floats
+        :return: a new TheoryVector
+        """
         array = np.array(vals)
         return cls.create(array)

@@ -61,20 +77,26 @@ def residuals(data: DataVector, theory: TheoryVector) -> npt.NDArray[np.float64]

 @dataclass
 class StatisticsResult:
-    """This is the type returned by the `compute` method of any `Statistic`."""
+    """A pair of a :python:`DataVector` and a :python:`TheoryVector`.
+
+    This is the type returned by the :meth:`compute` method of any :python:`Statistic`.
+    """

     data: DataVector
     theory: TheoryVector

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         """Make sure the data and theory vectors are of the same shape."""
         assert self.data.shape == self.theory.shape

     def residuals(self) -> npt.NDArray[np.float64]:
-        """Return the residuals -- the difference between data and theory."""
+        """Return the residuals -- the difference between data and theory.
+
+        :return: the residuals
+        """
         return self.data - self.theory

-    def __iter__(self):
+    def __iter__(self) -> Iterator[DataVector | TheoryVector]:
         """Iterate through the data members.

         This is to allow automatic unpacking, as if the StatisticsResult were a tuple
@@ -82,6 +104,8 @@ def __iter__(self):

         This method is a temporary measure to help code migrate to the newer,
         safer interface for Statistic.compute().
+
+        :return: an iterator object that yields first the data and then the theory
         """
         warnings.warn(
             "Iteration and tuple unpacking for StatisticsResult is "
@@ -100,6 +124,10 @@ class StatisticUnreadError(RuntimeError):
     """

     def __init__(self, stat: Statistic):
+        """Initialize a new StatisticUnreadError.
+
+        :param stat: the statistic that was accessed before `read` was called
+        """
         msg = (
             f"The statistic {stat} was used for calculation before `read` "
             f"was called.\nIt may be that a likelihood factory function did not"
@@ -120,6 +148,18 @@ class Statistic(Updatable):
     """

     def __init__(self, parameter_prefix: None | str = None):
+        """Initialize a new Statistic.
+
+        Derived classes should make sure to call this method using:
+
+        .. code-block:: python
+
+            super().__init__(parameter_prefix=parameter_prefix)
+
+        as the last thing they do in `__init__`.
+
+        :param parameter_prefix: The prefix to prepend to all parameter names
+        """
         super().__init__(parameter_prefix=parameter_prefix)
         self.sacc_indices: None | npt.NDArray[np.int64]
         self.ready = False
@@ -129,17 +169,16 @@ def __init__(self, parameter_prefix: None | str = None):
     def read(self, _: sacc.Sacc) -> None:
         """Read the data for this statistic and mark it as ready for use.

-        Derived classes that override this function should make sure to call the base
-        class method using:
+        Derived classes that override this function should make sure to call the
+        base class method using:

         .. code-block:: python

             super().read(sacc_data)

-        as the last thing they do in `__init__`.
+        as the last thing they do.

-        Note that currently the argument is not used; it is present so that this
-        method will have the correct argument type for the override.
+        :param _: currently unused, but required by the interface.
         """
         self.ready = True
         if len(self.get_data_vector()) == 0:
@@ -151,18 +190,32 @@ class method using:
     def _reset(self):
         """Reset this statistic.

-        All subclasses implementations must call super()._reset()
+        Derived classes that override this function should make sure to call the
+        base class method using:
+
+        .. code-block:: python
+
+            super()._reset()
+
+        as the last thing they do.
         """
         self.computed_theory_vector = False
         self.theory_vector = None

     @abstractmethod
     def get_data_vector(self) -> DataVector:
-        """Gets the statistic data vector."""
+        """Gets the statistic data vector.
+
+        :return: The data vector.
+        """

     @final
     def compute_theory_vector(self, tools: ModelingTools) -> TheoryVector:
-        """Compute a statistic from sources, applying any systematics."""
+        """Compute a statistic from sources, applying any systematics.
+
+        :param tools: the modeling tools used to compute the theory vector.
+        :return: The computed theory vector.
+        """
         if not self.is_updated():
             raise RuntimeError(
                 f"The statistic {self} has not been updated with parameters."
@@ -180,6 +233,8 @@ def get_theory_vector(self) -> TheoryVector:
         """Returns the last computed theory vector.

         Raises a RuntimeError if the vector has not been computed.
+
+        :return: The already-computed theory vector.
""" if not self.computed_theory_vector: raise RuntimeError( @@ -200,7 +255,10 @@ class GuardedStatistic(Updatable): """ def __init__(self, stat: Statistic): - """Initialize the GuardedStatistic to contain the given :class:`Statistic`.""" + """Initialize the GuardedStatistic to contain the given :class:`Statistic`. + + :param stat: The statistic to wrap. + """ super().__init__() assert isinstance(stat, Statistic) self.statistic = stat @@ -211,6 +269,8 @@ def read(self, sacc_data: sacc.Sacc) -> None: After this function is called, the object should be prepared for the calling of the methods :meth:`get_data_vector` and :meth:`compute_theory_vector`. + + :param sacc_data: The SACC data object to read from. """ if self.statistic.ready: raise RuntimeError("Firecrown has called read twice on a GuardedStatistic") @@ -229,6 +289,8 @@ def get_data_vector(self) -> DataVector: :class:`GuardedStatistic` ensures that :meth:`read` has been called. first. + + :return: The most recently calculated data vector. """ if not self.statistic.ready: raise StatisticUnreadError(self.statistic) @@ -239,6 +301,9 @@ def compute_theory_vector(self, tools: ModelingTools) -> TheoryVector: :class:`GuardedStatistic` ensures that :meth:`read` has been called. first. + + :param tools: the modeling tools used to compute the theory vector. + :return: The computed theory vector. """ if not self.statistic.ready: raise StatisticUnreadError(self.statistic) @@ -264,8 +329,11 @@ def __init__(self) -> None: ) self.computed_theory_vector = False - def read(self, sacc_data: sacc.Sacc): - """Read the necessary items from the sacc data.""" + def read(self, sacc_data: sacc.Sacc) -> None: + """Read the necessary items from the sacc data. + + :param sacc_data: The SACC data object to be read + """ our_data = sacc_data.get_mean(data_type="count") assert len(our_data) == self.count self.data_vector = DataVector.from_list(our_data) @@ -274,19 +342,32 @@ def read(self, sacc_data: sacc.Sacc): @final def _required_parameters(self) -> RequiredParameters: - """Return an empty RequiredParameters.""" + """Return an empty RequiredParameters. + + :return: an empty RequiredParameters. + """ return RequiredParameters([]) @final def _get_derived_parameters(self) -> DerivedParameterCollection: - """Return an empty DerivedParameterCollection.""" + """Return an empty DerivedParameterCollection. + + :return: an empty DerivedParameterCollection. + """ return DerivedParameterCollection([]) def get_data_vector(self) -> DataVector: - """Return the data vector; raise exception if there is none.""" + """Return the data vector; raise exception if there is none. + + :return: The data vector. + """ assert self.data_vector is not None return self.data_vector def _compute_theory_vector(self, _: ModelingTools) -> TheoryVector: - """Return a fixed theory vector.""" + """Return a fixed theory vector. + + :param _: unused, but required by the interface + :return: A fixed theory vector + """ return TheoryVector.from_list([self.mean] * self.count) diff --git a/firecrown/likelihood/student_t.py b/firecrown/likelihood/student_t.py index 09c1e1cb..6d69a15e 100644 --- a/firecrown/likelihood/student_t.py +++ b/firecrown/likelihood/student_t.py @@ -20,9 +20,6 @@ class StudentT(GaussFamily): from a finite number of simulations. See Sellentin & Heavens (2016; arXiv:1511.05969). As the number of simulations increases, the T-distribution approaches a Gaussian. 
- - :param statistics: list of statistics to build the theory and data vectors - :param nu: The Student-t $\nu$ parameter """ def __init__( @@ -30,13 +27,19 @@ def __init__( statistics: list[Statistic], nu: None | float = None, ): + """Initialize a StudentT object. + + :param statistics: The statistics to use in the likelihood + :param nu: The degrees of freedom of the T-distribution + """ super().__init__(statistics) self.nu = parameters.register_new_updatable_parameter(nu, default_value=3.0) - def compute_loglike(self, tools: ModelingTools): + def compute_loglike(self, tools: ModelingTools) -> float: """Compute the log-likelihood. - :param cosmo: Current Cosmology object + :param tools: The modeling tools used to compute the likelihood. + :return: The log-likelihood. """ chi2 = self.compute_chisq(tools) return -0.5 * self.nu * np.log(1.0 + chi2 / (self.nu - 1.0)) diff --git a/firecrown/likelihood/supernova.py b/firecrown/likelihood/supernova.py index 278d05f1..5dccec64 100644 --- a/firecrown/likelihood/supernova.py +++ b/firecrown/likelihood/supernova.py @@ -30,7 +30,11 @@ class Supernova(Statistic): """ def __init__(self, sacc_tracer: str) -> None: - """Initialize this statistic.""" + """Initialize a Supernova object. + + :param sacc_tracer: the name of the tracer in the SACC file. This is used + as a prefix for its parameters. + """ super().__init__(parameter_prefix=sacc_tracer) self.sacc_tracer = sacc_tracer @@ -41,7 +45,10 @@ def __init__(self, sacc_tracer: str) -> None: ) def read(self, sacc_data: sacc.Sacc) -> None: - """Read the data for this statistic from the SACC file.""" + """Read the data for this statistic from the SACC file. + + :param sacc_data: The data in the sacc format. + """ # We do not actually need the tracer, but we want to make sure the SACC # data contains this tracer. # TODO: remove the work-around when the new version of SACC supporting @@ -68,12 +75,19 @@ def read(self, sacc_data: sacc.Sacc) -> None: super().read(sacc_data) def get_data_vector(self) -> DataVector: - """Return the data vector; raise exception if there is none.""" + """Return the data vector; raise exception if there is none. + + :return: The data vector. + """ assert self.data_vector is not None return self.data_vector def _compute_theory_vector(self, tools: ModelingTools) -> TheoryVector: - """Compute SNIa distance statistic using CCL.""" + """Compute SNIa distance statistic using CCL. + + :param tools: the modeling tools used to compute the theory vector. + :return: The computed theory vector. + """ ccl_cosmo = tools.get_ccl_cosmology() prediction = self.M + pyccl.distance_modulus(ccl_cosmo, self.a) return TheoryVector.create(prediction) From 4eef092d29e8fbd0aa08c11691435e31f13d85c2 Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Tue, 10 Sep 2024 15:22:59 -0300 Subject: [PATCH 06/10] Restructuring metadata objects (#448) * First move of files, and adjustments * Remove of Window and related classes * Merge of TwoPointCell and TwoPointCWindow into TwoPointHarmonic. * Updated tests for the new object structure. * Finished inversion of data and metadata. * Adding tests for uncovered lines. * Updating tutorials to use new structure. * Testing window function application. * Removing left-over arrays from TwoPoint*Index dataclasses. * Putting ids on parametrized tests. 
--------- Co-authored-by: Marc Paterno --- firecrown/data_functions.py | 200 +++++++++ firecrown/data_types.py | 62 +++ firecrown/generators/inferred_galaxy_zdist.py | 6 +- firecrown/likelihood/gaussfamily.py | 2 +- firecrown/likelihood/number_counts.py | 2 +- firecrown/likelihood/two_point.py | 393 +++++++++--------- firecrown/likelihood/weak_lensing.py | 2 +- firecrown/metadata/__init__.py | 5 - .../two_point.py => metadata_functions.py} | 385 ++++++----------- .../two_point_types.py => metadata_types.py} | 206 ++------- tests/conftest.py | 71 ++-- .../generators/test_inferred_galaxy_zdist.py | 2 +- .../statistic/source/test_source.py | 18 +- .../gauss_family/statistic/test_two_point.py | 163 +++++++- .../gauss_family/test_const_gaussian.py | 27 +- tests/metadata/sacc_name_mapping.py | 2 +- tests/metadata/test_metadata_two_point.py | 356 +++++++--------- .../test_metadata_two_point_measurement.py | 212 ++++++---- .../metadata/test_metadata_two_point_sacc.py | 351 +++++++++------- .../metadata/test_metadata_two_point_types.py | 230 +++++----- tutorial/inferred_zdist.qmd | 14 +- tutorial/inferred_zdist_generators.qmd | 4 +- tutorial/inferred_zdist_serialization.qmd | 2 +- tutorial/two_point_factories.qmd | 28 +- tutorial/two_point_generators.qmd | 10 +- 25 files changed, 1464 insertions(+), 1289 deletions(-) create mode 100644 firecrown/data_functions.py create mode 100644 firecrown/data_types.py delete mode 100644 firecrown/metadata/__init__.py rename firecrown/{metadata/two_point.py => metadata_functions.py} (56%) rename firecrown/{metadata/two_point_types.py => metadata_types.py} (74%) diff --git a/firecrown/data_functions.py b/firecrown/data_functions.py new file mode 100644 index 00000000..69ef09b6 --- /dev/null +++ b/firecrown/data_functions.py @@ -0,0 +1,200 @@ +"""This module deals with two-point data functions. + +It contains functions to manipulate two-point data objects. 
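+
+A minimal sketch of the intended use (the file name here is hypothetical):
+
+.. code-block:: python
+
+    import sacc
+    from firecrown.data_functions import (
+        check_two_point_consistence_harmonic,
+        extract_all_harmonic_data,
+    )
+
+    sacc_data = sacc.Sacc.load_fits("my_data.fits")
+    measurements = extract_all_harmonic_data(sacc_data)
+    check_two_point_consistence_harmonic(measurements)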
+""" + +import hashlib +from typing import Sequence + +import sacc + +from firecrown.metadata_types import ( + TwoPointHarmonic, + TwoPointReal, +) +from firecrown.metadata_functions import ( + extract_all_tracers_inferred_galaxy_zdists, + extract_window_function, + extract_all_harmonic_metadata_indices, + extract_all_real_metadata_indices, + make_two_point_xy, +) +from firecrown.data_types import TwoPointMeasurement + + +def extract_all_harmonic_data( + sacc_data: sacc.Sacc, + allowed_data_type: None | list[str] = None, + include_maybe_types=False, +) -> list[TwoPointMeasurement]: + """Extract the two-point function metadata and data from a sacc file.""" + inferred_galaxy_zdists_dict = { + igz.bin_name: igz + for igz in extract_all_tracers_inferred_galaxy_zdists( + sacc_data, include_maybe_types=include_maybe_types + ) + } + + if sacc_data.covariance is None or sacc_data.covariance.dense is None: + raise ValueError("The SACC object does not have a covariance matrix.") + cov_hash = hashlib.sha256(sacc_data.covariance.dense).hexdigest() + + tpms: list[TwoPointMeasurement] = [] + for cell_index in extract_all_harmonic_metadata_indices( + sacc_data, allowed_data_type + ): + t1, t2 = cell_index["tracer_names"] + dt = cell_index["data_type"] + + ells, Cells, indices = sacc_data.get_ell_cl( + data_type=dt, tracer1=t1, tracer2=t2, return_cov=False, return_ind=True + ) + + replacement_ells, weights = extract_window_function(sacc_data, indices) + if replacement_ells is not None: + ells = replacement_ells + + tpms.append( + TwoPointMeasurement( + data=Cells, + indices=indices, + covariance_name=cov_hash, + metadata=TwoPointHarmonic( + XY=make_two_point_xy( + inferred_galaxy_zdists_dict, cell_index["tracer_names"], dt + ), + window=weights, + ells=ells, + ), + ), + ) + + return tpms + + +# Extracting the two-point function metadata and data from a sacc file + + +def extract_all_real_data( + sacc_data: sacc.Sacc, + allowed_data_type: None | list[str] = None, + include_maybe_types=False, +) -> list[TwoPointMeasurement]: + """Extract the two-point function metadata and data from a sacc file.""" + inferred_galaxy_zdists_dict = { + igz.bin_name: igz + for igz in extract_all_tracers_inferred_galaxy_zdists( + sacc_data, include_maybe_types=include_maybe_types + ) + } + + cov_hash = hashlib.sha256(sacc_data.covariance.dense).hexdigest() + + result: list[TwoPointMeasurement] = [] + for real_index in extract_all_real_metadata_indices(sacc_data, allowed_data_type): + t1, t2 = real_index["tracer_names"] + dt = real_index["data_type"] + + thetas, Xis, indices = sacc_data.get_theta_xi( + data_type=dt, tracer1=t1, tracer2=t2, return_cov=False, return_ind=True + ) + + result.append( + TwoPointMeasurement( + data=Xis, + indices=indices, + covariance_name=cov_hash, + metadata=TwoPointReal( + XY=make_two_point_xy( + inferred_galaxy_zdists_dict, real_index["tracer_names"], dt + ), + thetas=thetas, + ), + ) + ) + + return result + + +def check_two_point_consistence_harmonic( + two_point_harmonics: Sequence[TwoPointMeasurement], +) -> None: + """Check the indices of the harmonic-space two-point functions. + + Make sure the indices of the harmonic-space two-point functions are consistent. + """ + all_indices_set: set[int] = set() + index_set_list = [] + cov_name: None | str = None + + for harmonic in two_point_harmonics: + if not harmonic.is_harmonic(): + raise ValueError( + f"The metadata of the TwoPointMeasurement {harmonic} is not " + f"a measurement of TwoPointHarmonic." 
+ ) + if cov_name is None: + cov_name = harmonic.covariance_name + elif cov_name != harmonic.covariance_name: + raise ValueError( + f"The TwoPointHarmonic {harmonic} has a different covariance " + f"name {harmonic.covariance_name} than the previous " + f"TwoPointHarmonic {cov_name}." + ) + index_set = set(harmonic.indices) + index_set_list.append(index_set) + if len(index_set) != len(harmonic.indices): + raise ValueError( + f"The indices of the TwoPointHarmonic {harmonic} are not unique." + ) + + if all_indices_set & index_set: + for i, index_set_a in enumerate(index_set_list): + if index_set_a & index_set: + raise ValueError( + f"The indices of the TwoPointHarmonic " + f"{two_point_harmonics[i]} and {harmonic} overlap." + ) + all_indices_set.update(index_set) + + +def check_two_point_consistence_real( + two_point_reals: Sequence[TwoPointMeasurement], +) -> None: + """Check the indices of the real-space two-point functions. + + Make sure the indices of the real-space two-point functions are consistent. + """ + all_indices_set: set[int] = set() + index_set_list = [] + cov_name: None | str = None + + for two_point_real in two_point_reals: + if not two_point_real.is_real(): + raise ValueError( + f"The metadata of the TwoPointMeasurement {two_point_real} is not " + f"a measurement of TwoPointReal." + ) + if cov_name is None: + cov_name = two_point_real.covariance_name + elif cov_name != two_point_real.covariance_name: + raise ValueError( + f"The TwoPointReal {two_point_real} has a different covariance " + f"name {two_point_real.covariance_name} than the previous " + f"TwoPointReal {cov_name}." + ) + index_set = set(two_point_real.indices) + index_set_list.append(index_set) + if len(index_set) != len(two_point_real.indices): + raise ValueError( + f"The indices of the TwoPointReal {two_point_real} " f"are not unique." + ) + + if all_indices_set & index_set: + for i, index_set_a in enumerate(index_set_list): + if index_set_a & index_set: + raise ValueError( + f"The indices of the TwoPointReal {two_point_reals[i]} " + f"and {two_point_real} overlap." + ) + all_indices_set.update(index_set) diff --git a/firecrown/data_types.py b/firecrown/data_types.py new file mode 100644 index 00000000..91453c8c --- /dev/null +++ b/firecrown/data_types.py @@ -0,0 +1,62 @@ +"""This module deals with data types. + +This module contains data types definitions. +""" + +from dataclasses import dataclass + +import numpy as np +import numpy.typing as npt + +from firecrown.utils import YAMLSerializable +from firecrown.metadata_types import TwoPointReal, TwoPointHarmonic + + +@dataclass(frozen=True, kw_only=True) +class TwoPointMeasurement(YAMLSerializable): + """Class defining the data for a two-point measurement. + + The class used to store the data for a two-point function measured on a sphere. + + This includes the measured two-point function, their indices in the covariance + matrix and the name of the covariance matrix. The corresponding metadata is also + stored. 
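+
+ A sketch of constructing one of these objects (the argument names here
+ stand for arrays and metadata assumed to exist already):
+
+ .. code-block:: python
+
+     measurement = TwoPointMeasurement(
+         data=data_array,
+         indices=index_array,
+         covariance_name=cov_name,
+         metadata=harmonic_metadata,
+     )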
+ """ + + data: npt.NDArray[np.float64] + indices: npt.NDArray[np.int64] + covariance_name: str + metadata: TwoPointReal | TwoPointHarmonic + + def __post_init__(self) -> None: + """Make sure the data and indices have the same shape.""" + if len(self.data.shape) != 1: + raise ValueError("Data should be a 1D array.") + + if self.data.shape != self.indices.shape: + raise ValueError("Data and indices should have the same shape.") + + if not isinstance(self.metadata, (TwoPointReal, TwoPointHarmonic)): + raise ValueError( + "Metadata should be an instance of TwoPointReal or TwoPointHarmonic." + ) + + if len(self.data) != self.metadata.n_observations(): + raise ValueError("Data and metadata should have the same length.") + + def __eq__(self, other) -> bool: + """Equality test for TwoPointMeasurement objects.""" + return ( + np.array_equal(self.data, other.data) + and np.array_equal(self.indices, other.indices) + and self.covariance_name == other.covariance_name + and self.metadata == other.metadata + ) + + def is_real(self) -> bool: + """Check if the metadata is real.""" + return isinstance(self.metadata, TwoPointReal) + + def is_harmonic(self) -> bool: + """Check if the metadata is harmonic.""" + return isinstance(self.metadata, TwoPointHarmonic) diff --git a/firecrown/generators/inferred_galaxy_zdist.py b/firecrown/generators/inferred_galaxy_zdist.py index 430a2fc3..220972a6 100644 --- a/firecrown/generators/inferred_galaxy_zdist.py +++ b/firecrown/generators/inferred_galaxy_zdist.py @@ -12,17 +12,15 @@ from numcosmo_py import Ncm -from firecrown.metadata.two_point_types import ( +from firecrown.metadata_types import ( InferredGalaxyZDist, ALL_MEASUREMENT_TYPES, -) -from firecrown.metadata.two_point_types import ( make_measurements_dict, - Measurement, Galaxies, CMB, Clusters, ) +from firecrown.metadata_functions import Measurement BinsType = TypedDict("BinsType", {"edges": npt.NDArray, "sigma_z": float}) diff --git a/firecrown/likelihood/gaussfamily.py b/firecrown/likelihood/gaussfamily.py index 26516382..2c5e78a3 100644 --- a/firecrown/likelihood/gaussfamily.py +++ b/firecrown/likelihood/gaussfamily.py @@ -269,7 +269,7 @@ def _set_covariance(self, covariance: npt.NDArray[np.float64]) -> None: data_vector = np.concatenate(data_vector_list) cov = np.zeros((len(indices), len(indices))) - largest_index = np.max(indices) + largest_index = int(np.max(indices)) if not ( covariance.ndim == 2 diff --git a/firecrown/likelihood/number_counts.py b/firecrown/likelihood/number_counts.py index 7fa73b47..fe4b528f 100644 --- a/firecrown/likelihood/number_counts.py +++ b/firecrown/likelihood/number_counts.py @@ -24,7 +24,7 @@ SourceGalaxySystematic, Tracer, ) -from firecrown.metadata.two_point import InferredGalaxyZDist +from firecrown.metadata_types import InferredGalaxyZDist from firecrown.modeling_tools import ModelingTools from firecrown.parameters import DerivedParameter, DerivedParameterCollection, ParamsMap from firecrown.updatable import UpdatableCollection diff --git a/firecrown/likelihood/two_point.py b/firecrown/likelihood/two_point.py index bfc77a70..0e7fe3d3 100644 --- a/firecrown/likelihood/two_point.py +++ b/firecrown/likelihood/two_point.py @@ -31,26 +31,23 @@ Statistic, TheoryVector, ) -from firecrown.metadata.two_point_types import ( - TRACER_NAMES_TOTAL, - InferredGalaxyZDist, +from firecrown.metadata_types import ( Galaxies, + InferredGalaxyZDist, Measurement, + TRACER_NAMES_TOTAL, + TracerNames, + TwoPointHarmonic, + TwoPointReal, ) -from firecrown.metadata.two_point import ( - 
TracerNames, - TwoPointCells, - TwoPointCWindow, - TwoPointXiTheta, - TwoPointCellsIndex, - TwoPointXiThetaIndex, - Window, +from firecrown.metadata_functions import ( + TwoPointHarmonicIndex, + TwoPointRealIndex, extract_window_function, - check_two_point_consistence_harmonic, - check_two_point_consistence_real, measurements_from_index, ) +from firecrown.data_types import TwoPointMeasurement from firecrown.modeling_tools import ModelingTools from firecrown.updatable import UpdatableCollection @@ -124,7 +121,9 @@ def make_log_interpolator(x, y): return lambda x_, intp=intp: intp(np.log(x_)) -def calculate_ells_for_interpolation(w: Window) -> npt.NDArray[np.int64]: +def calculate_ells_for_interpolation( + min_ell: int, max_ell: int +) -> npt.NDArray[np.int64]: """See _ell_for_xi. This method mixes together: @@ -134,11 +133,9 @@ def calculate_ells_for_interpolation(w: Window) -> npt.NDArray[np.int64]: and then calls _ell_for_xi with those arguments, returning whatever it returns. """ - ell_config = { - **ELL_FOR_XI_DEFAULTS, - "maximum": w.ells[-1], - } - ell_config["minimum"] = max(ell_config["minimum"], w.ells[0]) + ell_config = copy.deepcopy(ELL_FOR_XI_DEFAULTS) + ell_config["maximum"] = max_ell + ell_config["minimum"] = max(ell_config["minimum"], min_ell) return _ell_for_xi(**ell_config) @@ -170,7 +167,7 @@ def generate_ells_cells(ell_config: EllOrThetaConfig): return ells, Cells -def generate_theta_xis(theta_config: EllOrThetaConfig): +def generate_reals(theta_config: EllOrThetaConfig): """Generate theta and xi values from the configuration dictionary.""" thetas = _generate_ell_or_theta(**theta_config) xis = np.zeros_like(thetas) @@ -263,7 +260,7 @@ def use_source_factory( return source -def use_source_factory_metadata_only( +def use_source_factory_metadata_index( sacc_tracer: str, measurement: Measurement, wl_factory: WeakLensingFactory | None = None, @@ -402,7 +399,7 @@ def __init__( self.ell_or_theta_config: None | EllOrThetaConfig self.ell_or_theta_min: None | float | int self.ell_or_theta_max: None | float | int - self.window: None | Window + self.window: None | npt.NDArray[np.float64] self.data_vector: None | DataVector self.theory_vector: None | TheoryVector self.sacc_tracers: TracerNames @@ -449,118 +446,36 @@ def _set_ccl_kind(self, sacc_data_type): raise ValueError(f"The SACC data type {sacc_data_type} is not supported!") @classmethod - def _from_metadata( + def from_metadata_index( cls, - *, - sacc_data_type: str, - source0: Source, - source1: Source, - metadata: TwoPointCells | TwoPointCWindow | TwoPointXiTheta, - ) -> TwoPoint: - """Create a TwoPoint statistic from a TwoPointCells metadata object.""" - two_point = cls(sacc_data_type, source0, source1) - match metadata: - case TwoPointCells(): - two_point._init_from_cells(metadata) - case TwoPointCWindow(): - two_point._init_from_cwindow(metadata) - case TwoPointXiTheta(): - two_point._init_from_xi_theta(metadata) - case _: - raise ValueError(f"Metadata of type {type(metadata)} is not supported!") - return two_point - - def _init_from_cells(self, metadata: TwoPointCells): - """Initialize the TwoPoint statistic from a TwoPointCells metadata object.""" - self.sacc_tracers = metadata.XY.get_tracer_names() - self.ells = metadata.ells - self.window = None - if metadata.Cell is not None: - self.sacc_indices = metadata.Cell.indices - self.data_vector = DataVector.create(metadata.Cell.data) - self.ready = True - - def _init_from_cwindow(self, metadata: TwoPointCWindow): - """Initialize the TwoPoint statistic from a TwoPointCWindow 
metadata object."""
- self.sacc_tracers = metadata.XY.get_tracer_names()
- self.window = metadata.window
- if self.window.ells_for_interpolation is None:
- self.window.ells_for_interpolation = calculate_ells_for_interpolation(
- self.window
- )
- if metadata.Cell is not None:
- self.sacc_indices = metadata.Cell.indices
- self.data_vector = DataVector.create(metadata.Cell.data)
- self.ready = True
-
- def _init_from_xi_theta(self, metadata: TwoPointXiTheta):
- """Initialize the TwoPoint statistic from a TwoPointXiTheta metadata object."""
- self.sacc_tracers = metadata.XY.get_tracer_names()
- self.thetas = metadata.thetas
- self.window = None
- self.ells_for_xi = _ell_for_xi(**self.ell_for_xi_config)
- if metadata.xis is not None:
- self.sacc_indices = metadata.xis.indices
- self.data_vector = DataVector.create(metadata.xis.data)
- self.ready = True
-
- @classmethod
- def _from_metadata_any(
- cls,
- metadata: Sequence[TwoPointCells | TwoPointCWindow | TwoPointXiTheta],
+ metadata_indices: Sequence[TwoPointHarmonicIndex | TwoPointRealIndex],
 wl_factory: WeakLensingFactory | None = None,
 nc_factory: NumberCountsFactory | None = None,
 ) -> UpdatableCollection[TwoPoint]:
 """Create an UpdatableCollection of TwoPoint statistics.
 
 This constructor creates an UpdatableCollection of TwoPoint statistics from a
- list of TwoPointCells, TwoPointCWindow or TwoPointXiTheta metadata objects.
- The metadata objects are used to initialize the TwoPoint statistics.
- """
- two_point_list = [
- cls._from_metadata(
- sacc_data_type=cell.get_sacc_name(),
- source0=use_source_factory(
- cell.XY.x,
- cell.XY.x_measurement,
- wl_factory=wl_factory,
- nc_factory=nc_factory,
- ),
- source1=use_source_factory(
- cell.XY.y,
- cell.XY.y_measurement,
- wl_factory=wl_factory,
- nc_factory=nc_factory,
- ),
- metadata=cell,
- )
- for cell in metadata
- ]
+ list of TwoPointHarmonicIndex or TwoPointRealIndex metadata index objects. The
+ purpose of this constructor is to create a TwoPoint statistic from a metadata
+ index, which requires a follow-up call to `read` to read the data and metadata
+ from the SACC object.
 
- return UpdatableCollection(two_point_list)
-
- @classmethod
- def _from_metadata_only_any(
- cls,
- metadata: Sequence[TwoPointCellsIndex | TwoPointXiThetaIndex],
- wl_factory: WeakLensingFactory | None = None,
- nc_factory: NumberCountsFactory | None = None,
- ) -> UpdatableCollection[TwoPoint]:
- """Create an UpdatableCollection of TwoPoint statistics.
+ :param metadata_indices: The metadata index objects to initialize the TwoPoint
+ statistics.
+ :param wl_factory: The weak lensing factory to use.
+ :param nc_factory: The number counts factory to use.
 
- This constructor creates an UpdatableCollection of TwoPoint statistics from a
- list of TwoPointCells, TwoPointCWindow or TwoPointXiTheta metadata objects.
- The metadata objects are used to initialize the TwoPoint statistics.
+ :return: An UpdatableCollection of TwoPoint statistics.
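+
+ A sketch of the intended use (assuming ``metadata_indices``, the
+ factories, and ``sacc_data`` already exist):
+
+ .. code-block:: python
+
+     two_points = TwoPoint.from_metadata_index(
+         metadata_indices, wl_factory=wl_factory, nc_factory=nc_factory
+     )
+     for two_point in two_points:
+         two_point.read(sacc_data)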
""" two_point_list = [] - for cell_index in metadata: - n1, a, n2, b = measurements_from_index(cell_index) + for metadata_index in metadata_indices: + n1, a, n2, b = measurements_from_index(metadata_index) two_point = cls( - sacc_data_type=cell_index["data_type"], - source0=use_source_factory_metadata_only( + sacc_data_type=metadata_index["data_type"], + source0=use_source_factory_metadata_index( n1, a, wl_factory=wl_factory, nc_factory=nc_factory ), - source1=use_source_factory_metadata_only( + source1=use_source_factory_metadata_index( n2, b, wl_factory=wl_factory, nc_factory=nc_factory ), ) @@ -569,72 +484,136 @@ def _from_metadata_only_any( return UpdatableCollection(two_point_list) @classmethod - def from_metadata_harmonic( + def _from_metadata_single( cls, - metadata: Sequence[TwoPointCells | TwoPointCWindow], + *, + metadata: TwoPointHarmonic | TwoPointReal, wl_factory: WeakLensingFactory | None = None, nc_factory: NumberCountsFactory | None = None, - check_consistence: bool = False, - ) -> UpdatableCollection[TwoPoint]: - """Create an UpdatableCollection of harmonic space TwoPoint statistics. + ) -> TwoPoint: + """Create a single TwoPoint statistic from metadata. - This constructor creates an UpdatableCollection of TwoPoint statistics from a - list of TwoPointCells or TwoPointCWindow metadata objects. The metadata objects - are used to initialize the TwoPoint statistics. + This constructor creates a single TwoPoint statistic from a TwoPointHarmonic or + TwoPointReal metadata object. It requires the sources to be initialized before + calling this constructor. The metadata object is used to initialize the TwoPoint + statistic. No further calls to `read` are needed. + """ + match metadata: + case TwoPointHarmonic(): + two_point = cls._from_metadata_single_base( + metadata, wl_factory, nc_factory + ) + two_point.ells = metadata.ells + two_point.window = metadata.window + case TwoPointReal(): + two_point = cls._from_metadata_single_base( + metadata, wl_factory, nc_factory + ) + two_point.thetas = metadata.thetas + two_point.window = None + two_point.ells_for_xi = _ell_for_xi(**two_point.ell_for_xi_config) + case _: + raise ValueError(f"Metadata of type {type(metadata)} is not supported!") + two_point.ready = True + return two_point + + @classmethod + def _from_metadata_single_base(cls, metadata, wl_factory, nc_factory): + """Create a single TwoPoint statistic from metadata. + + Base method for creating a single TwoPoint statistic from metadata. - :param metadata: The metadata objects to initialize the TwoPoint statistics. + :param metadata: The metadata object to initialize the TwoPoint statistic. :param wl_factory: The weak lensing factory to use. :param nc_factory: The number counts factory to use. - :param check_consistence: Whether to check the consistence of the - metadata and data. + + :return: A TwoPoint statistic. 
""" - if check_consistence: - check_two_point_consistence_harmonic(metadata) - return cls._from_metadata_any(metadata, wl_factory, nc_factory) + source0 = use_source_factory( + metadata.XY.x, + metadata.XY.x_measurement, + wl_factory=wl_factory, + nc_factory=nc_factory, + ) + source1 = use_source_factory( + metadata.XY.y, + metadata.XY.y_measurement, + wl_factory=wl_factory, + nc_factory=nc_factory, + ) + two_point = cls(metadata.get_sacc_name(), source0, source1) + two_point.sacc_tracers = metadata.XY.get_tracer_names() + return two_point @classmethod - def from_metadata_real( + def from_metadata( cls, - metadata: Sequence[TwoPointXiTheta], + metadata_seq: Sequence[TwoPointHarmonic | TwoPointReal], wl_factory: WeakLensingFactory | None = None, nc_factory: NumberCountsFactory | None = None, - check_consistence: bool = False, ) -> UpdatableCollection[TwoPoint]: - """Create an UpdatableCollection of real space TwoPoint statistics. + """Create an UpdatableCollection of TwoPoint statistics from metadata. This constructor creates an UpdatableCollection of TwoPoint statistics from a - list of TwoPointXiTheta metadata objects. The metadata objects are used to - initialize the TwoPoint statistics. + list of TwoPointHarmonic or TwoPointReal metadata objects. The metadata objects + are used to initialize the TwoPoint statistics. The sources are initialized + using the factories provided. + + Note that TwoPoint created with this constructor are ready to be used, but + contain no data. - :param metadata: The metadata objects to initialize the TwoPoint statistics. + :param metadata_seq: The metadata objects to initialize the TwoPoint statistics. :param wl_factory: The weak lensing factory to use. :param nc_factory: The number counts factory to use. - :param check_consistence: Whether to check the consistence of the - metadata and data. + + :return: An UpdatableCollection of TwoPoint statistics. """ - if check_consistence: - check_two_point_consistence_real(metadata) - return cls._from_metadata_any(metadata, wl_factory, nc_factory) + two_point_list = [ + cls._from_metadata_single( + metadata=metadata, wl_factory=wl_factory, nc_factory=nc_factory + ) + for metadata in metadata_seq + ] - @classmethod - def from_metadata_only_harmonic( - cls, - metadata: list[TwoPointCellsIndex], - wl_factory: WeakLensingFactory | None = None, - nc_factory: NumberCountsFactory | None = None, - ) -> UpdatableCollection[TwoPoint]: - """Create a TwoPoint from metadata only.""" - return cls._from_metadata_only_any(metadata, wl_factory, nc_factory) + return UpdatableCollection(two_point_list) @classmethod - def from_metadata_only_real( + def from_measurement( cls, - metadata: list[TwoPointXiThetaIndex], + measurements: Sequence[TwoPointMeasurement], wl_factory: WeakLensingFactory | None = None, nc_factory: NumberCountsFactory | None = None, ) -> UpdatableCollection[TwoPoint]: - """Create a TwoPoint from metadata only.""" - return cls._from_metadata_only_any(metadata, wl_factory, nc_factory) + """Create an UpdatableCollection of TwoPoint statistics from measurements. + + This constructor creates an UpdatableCollection of TwoPoint statistics from a + list of TwoPointMeasurement objects. The measurements are used to initialize the + TwoPoint statistics. The sources are initialized using the factories provided. + + Note that TwoPoint created with this constructor are ready to be used and + contain data. + + :param measurements: The measurements objects to initialize the TwoPoint + statistics. 
+ :param wl_factory: The weak lensing factory to use. + :param nc_factory: The number counts factory to use. + + :return: An UpdatableCollection of TwoPoint statistics. + """ + two_point_list: list[TwoPoint] = [] + for measurement in measurements: + two_point = cls._from_metadata_single( + metadata=measurement.metadata, + wl_factory=wl_factory, + nc_factory=nc_factory, + ) + two_point.sacc_indices = measurement.indices + two_point.data_vector = DataVector.create(measurement.data) + two_point.ready = True + + two_point_list.append(two_point) + + return UpdatableCollection(two_point_list) def read_ell_cells( self, sacc_data_type: str, sacc_data: sacc.Sacc, tracers: TracerNames @@ -658,7 +637,7 @@ def read_ell_cells( return ells, Cells, sacc_indices - def read_theta_xis( + def read_reals( self, sacc_data_type: str, sacc_data: sacc.Sacc, tracers: TracerNames ) -> ( None @@ -666,7 +645,7 @@ def read_theta_xis( ): """Read and return theta and xi.""" thetas, xis = sacc_data.get_theta_xi(sacc_data_type, *tracers, return_cov=False) - # As version 0.13 of sacc, the method get_theta_xi returns the + # As version 0.13 of sacc, the method get_real returns the # theta values and the xi values in arrays of the same length. assert len(thetas) == len(xis) @@ -694,7 +673,7 @@ def read(self, sacc_data: sacc.Sacc) -> None: def read_real_space(self, sacc_data: sacc.Sacc): """Read the data for this statistic from the SACC file.""" - thetas_xis_indices = self.read_theta_xis( + thetas_xis_indices = self.read_reals( self.sacc_data_type, sacc_data, self.sacc_tracers ) # We do not support window functions for real space statistics @@ -722,7 +701,7 @@ def read_real_space(self, sacc_data: sacc.Sacc): "have no 2pt data in the SACC file and no input theta values " "were given!" ) - thetas, xis = generate_theta_xis(self.ell_or_theta_config) + thetas, xis = generate_reals(self.ell_or_theta_config) sacc_indices = None assert isinstance(self.ell_or_theta_min, (float, type(None))) assert isinstance(self.ell_or_theta_max, (float, type(None))) @@ -750,14 +729,16 @@ def read_harmonic_space(self, sacc_data: sacc.Sacc): f"specified `ell` in the configuration. `ell` is being ignored!", stacklevel=2, ) - window = extract_window_function(sacc_data, sacc_indices) + replacement_ells: None | npt.NDArray[np.int64] + window: None | npt.NDArray[np.float64] + replacement_ells, window = extract_window_function(sacc_data, sacc_indices) if window is not None: # When using a window function, we do not calculate all Cl's. # For this reason we have a default set of ells that we use # to compute Cl's, and we have a set of ells used for # interpolation. - window.ells_for_interpolation = calculate_ells_for_interpolation(window) - + assert replacement_ells is not None + ells = replacement_ells else: if self.ell_or_theta_config is None: # The SACC file has no data points, just a tracer, in this case we @@ -843,17 +824,16 @@ def compute_theory_vector_harmonic_space( scale1 = self.source1.get_scale() assert self.ccl_kind == "cl" - assert (self.ells is not None) or (self.window is not None) + assert self.ells is not None if self.window is not None: - # If a window function is provided, we need to compute the Cl's - # for the ells used in the window function. To do this, we will - # first compute the Cl's for the ells used in the interpolation - # and then interpolate the results to the ells used in the window - # function. 
- assert self.window.ells_for_interpolation is not None - cells_for_interpolation = self.compute_cells( - self.window.ells_for_interpolation, + ells_for_interpolation = calculate_ells_for_interpolation( + self.ells[0], self.ells[-1] + ) + + cells_interpolated = self.compute_cells_interpolated( + self.ells, + ells_for_interpolation, scale0, scale1, tools, @@ -861,26 +841,11 @@ def compute_theory_vector_harmonic_space( tracers1, ) - # TODO: There is no code in Firecrown, neither test nor example, - # that exercises a theory window function in any way. - cell_interpolator = make_log_interpolator( - self.window.ells_for_interpolation, cells_for_interpolation - ) - # Deal with ell=0 and ell=1 - cells_interpolated = np.zeros(self.window.ells.size) - cells_interpolated[2:] = cell_interpolator(self.window.ells[2:]) - - # Here we left multiply the computed Cl's by the window function - # to get the final Cl's. - theory_vector = np.einsum( - "lb, l -> b", - self.window.weights, - cells_interpolated, - ) + # Here we left multiply the computed Cl's by the window function to get the + # final Cl's. + theory_vector = np.einsum("lb, l -> b", self.window, cells_interpolated) # We also compute the mean ell value associated with each bin. - self.mean_ells = np.einsum( - "lb, l -> b", self.window.weights, self.window.ells - ) + self.mean_ells = np.einsum("lb, l -> b", self.window, self.ells) assert self.data_vector is not None return TheoryVector.create(theory_vector) @@ -939,6 +904,46 @@ def compute_cells( theory_vector = self.cells[TRACER_NAMES_TOTAL] return theory_vector + def compute_cells_interpolated( + self, + ells: npt.NDArray[np.int64], + ells_for_interpolation: npt.NDArray[np.int64], + scale0: float, + scale1: float, + tools: ModelingTools, + tracers0: Sequence[Tracer], + tracers1: Sequence[Tracer], + ) -> npt.NDArray[np.float64]: + """Compute the interpolated power spectrum for the given ells and tracers. + + :param ells: The angular wavenumbers at which to compute the power spectrum. + :param ells_for_interpolation: The angular wavenumbers at which the power + spectrum is computed for interpolation. + :param scale0: The scale factor for the first tracer. + :param scale1: The scale factor for the second tracer. + :param tools: The modeling tools to use. + :param tracers0: The first tracers to use. + :param tracers1: The second tracers to use. + + Compute the power spectrum for the given ells and tracers and interpolate + the result to the ells provided. + + :return: The interpolated power spectrum. 
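+
+ A sketch of the scheme implemented in the body (``ell`` values 0 and 1
+ are left at zero rather than interpolated):
+
+ .. code-block:: python
+
+     interp = make_log_interpolator(ells_for_interpolation, computed_cells)
+     cells = np.zeros(len(ells))
+     cells[ells > 1] = interp(ells[ells > 1])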
+ """ + computed_cells = self.compute_cells( + ells_for_interpolation, scale0, scale1, tools, tracers0, tracers1 + ) + cell_interpolator = make_log_interpolator( + ells_for_interpolation, computed_cells + ) + cell_interpolated = np.zeros(len(ells)) + # We should not interpolate ell 0 and 1 + ells_larger_than_1 = ells > 1 + cell_interpolated[ells_larger_than_1] = cell_interpolator( + ells[ells_larger_than_1] + ) + return cell_interpolated + def calculate_pk( self, pk_name: str, tools: ModelingTools, tracer0: Tracer, tracer1: Tracer ): diff --git a/firecrown/likelihood/weak_lensing.py b/firecrown/likelihood/weak_lensing.py index b1102df8..677083ad 100644 --- a/firecrown/likelihood/weak_lensing.py +++ b/firecrown/likelihood/weak_lensing.py @@ -25,7 +25,7 @@ SourceGalaxySystematic, Tracer, ) -from firecrown.metadata.two_point import InferredGalaxyZDist +from firecrown.metadata_types import InferredGalaxyZDist from firecrown.modeling_tools import ModelingTools from firecrown.parameters import ParamsMap diff --git a/firecrown/metadata/__init__.py b/firecrown/metadata/__init__.py deleted file mode 100644 index f72b566e..00000000 --- a/firecrown/metadata/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Classes used to define and implement likelihood metadata. - -""" - -# flake8: noqa diff --git a/firecrown/metadata/two_point.py b/firecrown/metadata_functions.py similarity index 56% rename from firecrown/metadata/two_point.py rename to firecrown/metadata_functions.py index 1930bd5a..013db0ab 100644 --- a/firecrown/metadata/two_point.py +++ b/firecrown/metadata_functions.py @@ -1,28 +1,24 @@ -"""This module deals with two-point functions metadata. +"""This module deals with two-point metadata functions. -It contains all data classes and functions for store and extract two-point functions -metadata from a sacc file. +It contains functions used to manipulate two-point metadata, including extracting +metadata from a sacc file and creating new metadata objects. """ from itertools import combinations_with_replacement, product -import hashlib -from typing import TypedDict, Sequence +from typing import TypedDict import numpy as np import numpy.typing as npt import sacc from sacc.data_types import required_tags -from firecrown.metadata.two_point_types import ( +from firecrown.metadata_types import ( TracerNames, Measurement, InferredGalaxyZDist, TwoPointXY, - Window, - TwoPointCells, - TwoPointCWindow, - TwoPointXiTheta, - TwoPointMeasurement, + TwoPointHarmonic, + TwoPointReal, LENS_REGEX, SOURCE_REGEX, MEASURED_TYPE_STRING_MAP, @@ -31,32 +27,28 @@ GALAXY_SOURCE_TYPES, ) - -# TwoPointXiThetaIndex is a type used to create intermediate objects when -# reading SACC objects. They should not be seen directly by users of Firecrown. -TwoPointXiThetaIndex = TypedDict( - "TwoPointXiThetaIndex", +# TwoPointRealIndex is a type used to create intermediate objects when reading SACC +# objects. They should not be seen directly by users of Firecrown. +TwoPointRealIndex = TypedDict( + "TwoPointRealIndex", { "data_type": str, "tracer_names": TracerNames, - "thetas": npt.NDArray[np.float64], }, ) - -# TwoPointCellsIndex is a type used to create intermediate objects when reading -# SACC objects. They should not be seen directly by users of Firecrown. -TwoPointCellsIndex = TypedDict( - "TwoPointCellsIndex", +# TwoPointHarmonicIndex is a type used to create intermediate objects when reading SACC +# objects. They should not be seen directly by users of Firecrown. 
+TwoPointHarmonicIndex = TypedDict( + "TwoPointHarmonicIndex", { "data_type": str, "tracer_names": TracerNames, - "ells": npt.NDArray[np.int64], }, ) -def _extract_all_candidate_data_types( +def _extract_all_candidate_measurement_types( data_points: list[sacc.DataPoint], include_maybe_types: bool = False, ) -> dict[str, set[Measurement]]: @@ -159,7 +151,7 @@ def match_name_type( return False, tracer1, a, tracer2, b -def extract_all_tracers( +def extract_all_tracers_inferred_galaxy_zdists( sacc_data: sacc.Sacc, include_maybe_types=False ) -> list[InferredGalaxyZDist]: """Extracts the two-point function metadata from a Sacc object. @@ -171,7 +163,7 @@ def extract_all_tracers( and returns it in a list. """ tracers: list[sacc.tracers.BaseTracer] = sacc_data.tracers.values() - tracer_types = extract_all_tracers_types( + tracer_types = extract_all_measured_types( sacc_data, include_maybe_types=include_maybe_types ) for tracer0, tracer_types0 in tracer_types.items(): @@ -192,7 +184,7 @@ def extract_all_tracers( ] -def extract_all_tracers_types( +def extract_all_measured_types( sacc_data: sacc.Sacc, include_maybe_types: bool = False, ) -> dict[str, set[Measurement]]: @@ -206,13 +198,13 @@ def extract_all_tracers_types( """ data_points = sacc_data.get_data_points() - return _extract_all_candidate_data_types(data_points, include_maybe_types) + return _extract_all_candidate_measurement_types(data_points, include_maybe_types) -def extract_all_data_types_xi_thetas( +def extract_all_real_metadata_indices( sacc_data: sacc.Sacc, allowed_data_type: None | list[str] = None, -) -> list[TwoPointXiThetaIndex]: +) -> list[TwoPointRealIndex]: """Extract all two-point function metadata from a sacc file. Extracts the two-point function measurement metadata for all measurements @@ -222,40 +214,37 @@ def extract_all_data_types_xi_thetas( data_types = sacc_data.get_data_types() - data_types_xi_thetas = [ + data_types_reals = [ data_type for data_type in data_types if tag_name in required_tags[data_type] ] if allowed_data_type is not None: - data_types_xi_thetas = [ + data_types_reals = [ data_type - for data_type in data_types_xi_thetas + for data_type in data_types_reals if data_type in allowed_data_type ] - all_xi_thetas: list[TwoPointXiThetaIndex] = [] - for data_type in data_types_xi_thetas: + all_real_indices: list[TwoPointRealIndex] = [] + for data_type in data_types_reals: for combo in sacc_data.get_tracer_combinations(data_type): if len(combo) != 2: raise ValueError( f"Tracer combination {combo} does not have exactly two tracers." ) - all_xi_thetas.append( + all_real_indices.append( { "data_type": data_type, "tracer_names": TracerNames(*combo), - "thetas": np.array( - sacc_data.get_tag(tag_name, data_type=data_type, tracers=combo) - ), } ) - return all_xi_thetas + return all_real_indices -def extract_all_data_types_cells( +def extract_all_harmonic_metadata_indices( sacc_data: sacc.Sacc, allowed_data_type: None | list[str] = None -) -> list[TwoPointCellsIndex]: +) -> list[TwoPointHarmonicIndex]: """Extracts the two-point function metadata from a sacc file.""" tag_name = "ell" @@ -272,7 +261,7 @@ def extract_all_data_types_cells( if data_type in allowed_data_type ] - all_cell: list[TwoPointCellsIndex] = [] + all_harmonic_indices: list[TwoPointHarmonicIndex] = [] for data_type in data_types_cells: for combo in sacc_data.get_tracer_combinations(data_type): if len(combo) != 2: @@ -280,263 +269,116 @@ def extract_all_data_types_cells( f"Tracer combination {combo} does not have exactly two tracers." 
) - all_cell.append( + all_harmonic_indices.append( { "data_type": data_type, "tracer_names": TracerNames(*combo), - "ells": np.array( - sacc_data.get_tag(tag_name, data_type=data_type, tracers=combo) - ).astype(np.int64), } ) - return all_cell - - -def extract_all_photoz_bin_combinations( - sacc_data: sacc.Sacc, - include_maybe_types: bool = False, -) -> list[TwoPointXY]: - """Extracts the two-point function metadata from a sacc file.""" - inferred_galaxy_zdists = extract_all_tracers( - sacc_data, include_maybe_types=include_maybe_types - ) - bin_combinations = make_all_photoz_bin_combinations(inferred_galaxy_zdists) - - return bin_combinations - - -def extract_window_function( - sacc_data: sacc.Sacc, indices: npt.NDArray[np.int64] -) -> None | Window: - """Extract a window function from a sacc file that matches the given indices. - - If there is no appropriate window function, return None. - """ - bandpower_window = sacc_data.get_bandpower_windows(indices) - if bandpower_window is None: - return None - return Window( - ells=bandpower_window.values, - weights=bandpower_window.weight / bandpower_window.weight.sum(axis=0), - ) - - -def _build_two_point_xy( - inferred_galaxy_zdists_dict, tracer_names, data_type -) -> TwoPointXY: - """Build a TwoPointXY object from the inferred galaxy z distributions. - - The TwoPointXY object is built from the inferred galaxy z distributions, the data - type, and the tracer names. - """ - a, b = MEASURED_TYPE_STRING_MAP[data_type] - - igz1 = inferred_galaxy_zdists_dict[tracer_names[0]] - igz2 = inferred_galaxy_zdists_dict[tracer_names[1]] - - ab = a in igz1.measurements and b in igz2.measurements - ba = b in igz1.measurements and a in igz2.measurements - if a != b and ab and ba: - raise ValueError( - f"Ambiguous measurements for tracers {tracer_names}. " - f"Impossible to determine which measurement is from which tracer." 
- ) - XY = TwoPointXY( - x=igz1, y=igz2, x_measurement=a if ab else b, y_measurement=b if ab else a - ) - - return XY + return all_harmonic_indices -def extract_all_data_cells( +def extract_all_harmonic_metadata( sacc_data: sacc.Sacc, allowed_data_type: None | list[str] = None, include_maybe_types=False, -) -> tuple[list[TwoPointCells], list[TwoPointCWindow]]: +) -> list[TwoPointHarmonic]: """Extract the two-point function metadata and data from a sacc file.""" inferred_galaxy_zdists_dict = { igz.bin_name: igz - for igz in extract_all_tracers( + for igz in extract_all_tracers_inferred_galaxy_zdists( sacc_data, include_maybe_types=include_maybe_types ) } - if sacc_data.covariance is None or sacc_data.covariance.dense is None: - raise ValueError("The SACC object does not have a covariance matrix.") - cov_hash = hashlib.sha256(sacc_data.covariance.dense).hexdigest() - - two_point_cells = [] - two_point_cwindows = [] - for cell_index in extract_all_data_types_cells(sacc_data, allowed_data_type): + result: list[TwoPointHarmonic] = [] + for cell_index in extract_all_harmonic_metadata_indices( + sacc_data, allowed_data_type + ): tracer_names = cell_index["tracer_names"] - ells = cell_index["ells"] - data_type = cell_index["data_type"] + dt = cell_index["data_type"] - XY = _build_two_point_xy(inferred_galaxy_zdists_dict, tracer_names, data_type) + XY = make_two_point_xy(inferred_galaxy_zdists_dict, tracer_names, dt) - ells, Cells, indices = sacc_data.get_ell_cl( - data_type=data_type, - tracer1=tracer_names[0], - tracer2=tracer_names[1], - return_cov=False, - return_ind=True, + t1, t2 = tracer_names + ells, _, indices = sacc_data.get_ell_cl( + data_type=dt, tracer1=t1, tracer2=t2, return_cov=False, return_ind=True ) - window = extract_window_function(sacc_data, indices) - if window is not None: - two_point_cwindows.append( - TwoPointCWindow( - XY=XY, - window=window, - Cell=TwoPointMeasurement( - data=Cells, - indices=indices, - covariance_name=cov_hash, - ), - ) - ) - else: - two_point_cells.append( - TwoPointCells( - XY=XY, - ells=ells, - Cell=TwoPointMeasurement( - data=Cells, - indices=indices, - covariance_name=cov_hash, - ), - ) - ) + replacement_ells, weights = extract_window_function(sacc_data, indices) + if replacement_ells is not None: + ells = replacement_ells + + result.append(TwoPointHarmonic(XY=XY, window=weights, ells=ells)) - return two_point_cells, two_point_cwindows + return result -def extract_all_data_xi_thetas( +# Extracting all real metadata from a SACC object. 
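+#
+# For example (an illustrative sketch, assuming ``sacc_data`` is a loaded
+# sacc.Sacc object):
+#
+#     reals = extract_all_real_metadata(sacc_data)
+#     first_thetas = reals[0].thetas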
+ + +def extract_all_real_metadata( sacc_data: sacc.Sacc, allowed_data_type: None | list[str] = None, include_maybe_types=False, -) -> list[TwoPointXiTheta]: +) -> list[TwoPointReal]: """Extract the two-point function metadata and data from a sacc file.""" inferred_galaxy_zdists_dict = { igz.bin_name: igz - for igz in extract_all_tracers( + for igz in extract_all_tracers_inferred_galaxy_zdists( sacc_data, include_maybe_types=include_maybe_types ) } - cov_hash = hashlib.sha256(sacc_data.covariance.dense).hexdigest() + tprs: list[TwoPointReal] = [] + for real_index in extract_all_real_metadata_indices(sacc_data, allowed_data_type): + tracer_names = real_index["tracer_names"] + dt = real_index["data_type"] - two_point_xi_thetas = [] - for xi_theta_index in extract_all_data_types_xi_thetas( - sacc_data, allowed_data_type - ): - tracer_names = xi_theta_index["tracer_names"] - thetas = xi_theta_index["thetas"] - data_type = xi_theta_index["data_type"] - - XY = _build_two_point_xy(inferred_galaxy_zdists_dict, tracer_names, data_type) - - thetas, Xis, indices = sacc_data.get_theta_xi( - data_type=data_type, - tracer1=tracer_names[0], - tracer2=tracer_names[1], - return_cov=False, - return_ind=True, - ) + XY = make_two_point_xy(inferred_galaxy_zdists_dict, tracer_names, dt) - Xi = TwoPointMeasurement( - data=Xis, - indices=indices, - covariance_name=cov_hash, + t1, t2 = tracer_names + thetas, _, _ = sacc_data.get_theta_xi( + data_type=dt, tracer1=t1, tracer2=t2, return_cov=False, return_ind=True ) - two_point_xi_thetas.append(TwoPointXiTheta(XY=XY, thetas=thetas, xis=Xi)) - - return two_point_xi_thetas - + tprs.append(TwoPointReal(XY=XY, thetas=thetas)) -def check_two_point_consistence_harmonic( - two_point_cells: Sequence[TwoPointCells | TwoPointCWindow], -) -> None: - """Check the indices of the harmonic-space two-point functions. + return tprs - Make sure the indices of the harmonic-space two-point functions are consistent. - """ - all_indices_set: set[int] = set() - index_set_list = [] - cov_name: None | str = None - for two_point_cell in two_point_cells: - if two_point_cell.Cell is None: - raise ValueError( - f"The TwoPointCells {two_point_cell} does not contain a data." - ) - if cov_name is None: - cov_name = two_point_cell.Cell.covariance_name - elif cov_name != two_point_cell.Cell.covariance_name: - raise ValueError( - f"The TwoPointCells {two_point_cell} has a different covariance name " - f"{two_point_cell.Cell.covariance_name} than the previous " - f"TwoPointCells {cov_name}." - ) - index_set = set(two_point_cell.Cell.indices) - index_set_list.append(index_set) - if len(index_set) != len(two_point_cell.Cell.indices): - raise ValueError( - f"The indices of the TwoPointCells {two_point_cell} are not unique." - ) - - if all_indices_set & index_set: - for i, index_set_a in enumerate(index_set_list): - if index_set_a & index_set: - raise ValueError( - f"The indices of the TwoPointCells {two_point_cells[i]} and " - f"{two_point_cell} overlap." 
- ) - all_indices_set.update(index_set) +def extract_all_photoz_bin_combinations( + sacc_data: sacc.Sacc, + include_maybe_types: bool = False, +) -> list[TwoPointXY]: + """Extracts the two-point function metadata from a sacc file.""" + inferred_galaxy_zdists = extract_all_tracers_inferred_galaxy_zdists( + sacc_data, include_maybe_types=include_maybe_types + ) + bin_combinations = make_all_photoz_bin_combinations(inferred_galaxy_zdists) + return bin_combinations -def check_two_point_consistence_real( - two_point_xi_thetas: Sequence[TwoPointXiTheta], -) -> None: - """Check the indices of the real-space two-point functions. - Make sure the indices of the real-space two-point functions are consistent. +def extract_window_function( + sacc_data: sacc.Sacc, indices: npt.NDArray[np.int64] +) -> tuple[None | npt.NDArray[np.int64], None | npt.NDArray[np.float64]]: + """Extract ells and weights for a window function. + + :params sacc_data: the Sacc object from which we read. + :params indices: the indices of the data points in the Sacc object which + are computed by the window function. + :returns: the ells and weights of the window function that match the + given indices from a sacc object, or a tuple of (None, None) + if the indices represent the measured Cells directly. """ - all_indices_set: set[int] = set() - index_set_list = [] - cov_name: None | str = None - - for two_point_xi_theta in two_point_xi_thetas: - if two_point_xi_theta.xis is None: - raise ValueError( - f"The TwoPointXiTheta {two_point_xi_theta} does not contain a data." - ) - if cov_name is None: - cov_name = two_point_xi_theta.xis.covariance_name - elif cov_name != two_point_xi_theta.xis.covariance_name: - raise ValueError( - f"The TwoPointXiTheta {two_point_xi_theta} has a different covariance " - f"name {two_point_xi_theta.xis.covariance_name} than the previous " - f"TwoPointXiTheta {cov_name}." - ) - index_set = set(two_point_xi_theta.xis.indices) - index_set_list.append(index_set) - if len(index_set) != len(two_point_xi_theta.xis.indices): - raise ValueError( - f"The indices of the TwoPointXiTheta {two_point_xi_theta} " - f"are not unique." - ) - - if all_indices_set & index_set: - for i, index_set_a in enumerate(index_set_list): - if index_set_a & index_set: - raise ValueError( - f"The indices of the TwoPointXiTheta {two_point_xi_thetas[i]} " - f"and {two_point_xi_theta} overlap." - ) - all_indices_set.update(index_set) + bandpower_window = sacc_data.get_bandpower_windows(indices) + if bandpower_window is None: + return None, None + ells = bandpower_window.values + weights = bandpower_window.weight / bandpower_window.weight.sum(axis=0) + return ells, weights def make_all_photoz_bin_combinations( @@ -557,8 +399,43 @@ def make_all_photoz_bin_combinations( return bin_combinations +def make_two_point_xy( + inferred_galaxy_zdists_dict: dict[str, InferredGalaxyZDist], + tracer_names: TracerNames, + data_type: str, +) -> TwoPointXY: + """Build a TwoPointXY object from the inferred galaxy z distributions. + + The TwoPointXY object is built from the inferred galaxy z distributions, the data + type, and the tracer names. + + :param inferred_galaxy_zdists_dict: a dictionary of inferred galaxy z distributions. + :param tracer_names: a tuple of tracer names. + :param data_type: the data type. + + :returns: a TwoPointXY object. 
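+
+ If both tracers could supply either of the two measurements and the
+ measurements differ, the association is ambiguous and the body below
+ raises a ``ValueError``.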
+ """ + a, b = MEASURED_TYPE_STRING_MAP[data_type] + + igz1 = inferred_galaxy_zdists_dict[tracer_names[0]] + igz2 = inferred_galaxy_zdists_dict[tracer_names[1]] + + ab = a in igz1.measurements and b in igz2.measurements + ba = b in igz1.measurements and a in igz2.measurements + if a != b and ab and ba: + raise ValueError( + f"Ambiguous measurements for tracers {tracer_names}. " + f"Impossible to determine which measurement is from which tracer." + ) + XY = TwoPointXY( + x=igz1, y=igz2, x_measurement=a if ab else b, y_measurement=b if ab else a + ) + + return XY + + def measurements_from_index( - index: TwoPointXiThetaIndex | TwoPointCellsIndex, + index: TwoPointRealIndex | TwoPointHarmonicIndex, ) -> tuple[str, Measurement, str, Measurement]: """Return the measurements from a TwoPointXiThetaIndex object.""" a, b = MEASURED_TYPE_STRING_MAP[index["data_type"]] diff --git a/firecrown/metadata/two_point_types.py b/firecrown/metadata_types.py similarity index 74% rename from firecrown/metadata/two_point_types.py rename to firecrown/metadata_types.py index 6b8f644c..50486fe1 100644 --- a/firecrown/metadata/two_point_types.py +++ b/firecrown/metadata_types.py @@ -1,6 +1,6 @@ -"""This module deals with two-point types. +"""This module deals with metadata types. -This module contains two-point types definitions. +This module contains metadata types definitions. """ from itertools import chain, combinations_with_replacement @@ -11,7 +11,7 @@ import numpy as np import numpy.typing as npt -from firecrown.utils import compare_optional_arrays, compare_optionals, YAMLSerializable +from firecrown.utils import compare_optional_arrays, YAMLSerializable @dataclass(frozen=True) @@ -312,37 +312,6 @@ def __eq__(self, other): ) -@dataclass(frozen=True, kw_only=True) -class TwoPointMeasurement(YAMLSerializable): - """Class defining the metadata for a two-point measurement. - - The class used to store the metadata for a two-point function measured on a sphere. - - This includes the measured two-point function and their indices in the covariance - matrix. - """ - - data: npt.NDArray[np.float64] - indices: npt.NDArray[np.int64] - covariance_name: str - - def __post_init__(self) -> None: - """Make sure the data and indices have the same shape.""" - if len(self.data.shape) != 1: - raise ValueError("Data should be a 1D array.") - - if self.data.shape != self.indices.shape: - raise ValueError("Data and indices should have the same shape.") - - def __eq__(self, other) -> bool: - """Equality test for TwoPointMeasurement objects.""" - return ( - np.array_equal(self.data, other.data) - and np.array_equal(self.indices, other.indices) - and self.covariance_name == other.covariance_name - ) - - def make_measurements_dict(value: set[Measurement]) -> list[dict[str, str]]: """Create a dictionary from a Measurement object. @@ -526,7 +495,7 @@ def get_tracer_names(self) -> TracerNames: @dataclass(frozen=True, kw_only=True) -class TwoPointCells(YAMLSerializable): +class TwoPointHarmonic(YAMLSerializable): """Class defining the metadata for an harmonic-space two-point measurement. The class used to store the metadata for a (spherical) harmonic-space two-point @@ -539,10 +508,10 @@ class TwoPointCells(YAMLSerializable): XY: TwoPointXY ells: npt.NDArray[np.int64] - Cell: None | TwoPointMeasurement = None + window: None | npt.NDArray[np.float64] = None def __post_init__(self) -> None: - """Validate the TwoPointCells data. + """Validate the TwoPointHarmonic data. 
Make sure the ells are a 1D array and X and Y are compatible with harmonic-space calculations. @@ -550,8 +519,13 @@ def __post_init__(self) -> None: if len(self.ells.shape) != 1: raise ValueError("Ells should be a 1D array.") - if self.Cell is not None and self.Cell.data.shape != self.ells.shape: - raise ValueError("Cell should have the same shape as ells.") + if self.window is not None: + if not isinstance(self.window, np.ndarray): + raise ValueError("window should be a ndarray.") + if len(self.window.shape) != 2: + raise ValueError("window should be a 2D array.") + if self.window.shape[0] != len(self.ells): + raise ValueError("window should have the same number of rows as ells.") if not measurement_supports_harmonic( self.XY.x_measurement @@ -562,15 +536,18 @@ def __post_init__(self) -> None: ) def __eq__(self, other) -> bool: - """Equality test for TwoPointCells objects.""" + """Equality test for TwoPointHarmonic objects.""" + if not isinstance(other, TwoPointHarmonic): + raise ValueError("Can only compare TwoPointHarmonic objects.") + return ( self.XY == other.XY and np.array_equal(self.ells, other.ells) - and compare_optionals(self.Cell, other.Cell) + and compare_optional_arrays(self.window, other.window) ) def __str__(self) -> str: - """Return a string representation of the TwoPointCells object.""" + """Return a string representation of the TwoPointHarmonic object.""" return f"{self.XY}[{self.get_sacc_name()}]" def get_sacc_name(self) -> str: @@ -579,125 +556,18 @@ def get_sacc_name(self) -> str: self.XY.x_measurement, self.XY.y_measurement ) - def has_data(self) -> bool: - """Return True if the TwoPointCells object has a Cell array.""" - return self.Cell is not None - - -@dataclass(kw_only=True) -class Window(YAMLSerializable): - """The class used to represent a window function. - - It contains the ells at which the window function is defined, the weights - of the window function, and the ells at which the window function is - interpolated. - - It may contain the ells for interpolation if the theory prediction is - calculated at a different set of ells than the window function. - """ - - ells: npt.NDArray[np.int64] - weights: npt.NDArray[np.float64] - ells_for_interpolation: None | npt.NDArray[np.int64] = None - - def __post_init__(self) -> None: - """Make sure the weights have the right shape.""" - if len(self.ells.shape) != 1: - raise ValueError("Ells should be a 1D array.") - if len(self.weights.shape) != 2: - raise ValueError("Weights should be a 2D array.") - if self.weights.shape[0] != len(self.ells): - raise ValueError("Weights should have the same number of rows as ells.") - if ( - self.ells_for_interpolation is not None - and len(self.ells_for_interpolation.shape) != 1 - ): - raise ValueError("Ells for interpolation should be a 1D array.") - def n_observations(self) -> int: - """Return the number of observations supported by the window function.""" - return self.weights.shape[1] - - def __eq__(self, other) -> bool: - """Equality test for Window objects.""" - assert isinstance(other, Window) - # We will need special handling for the optional ells_for_interpolation. - # First handle the non-optinal parts. 
-        partial_result = np.array_equal(self.ells, other.ells) and np.array_equal(
-            self.weights, other.weights
-        )
-        if not partial_result:
-            return False
-        return compare_optional_arrays(
-            self.ells_for_interpolation, other.ells_for_interpolation
-        )
-
-
-@dataclass(frozen=True, kw_only=True)
-class TwoPointCWindow(YAMLSerializable):
-    """Two-point function with a window function.
-
-    The class used to store the metadata for a (spherical) harmonic-space two-point
-    function measured on a sphere, with an associated window function.
-
-    This includes the two redshift resolutions (one for each binned quantity) and the
-    matrix (window function) that relates the measured Cl's with the predicted Cl's.
-
-    Note that the matrix `window` always has l=0 and l=1 suppressed.
-    """
-
-    XY: TwoPointXY
-    window: Window
-    Cell: None | TwoPointMeasurement = None
-
-    def __post_init__(self):
-        """Validate the TwoPointCWindow data.
-
-        Make sure the window is
-        """
-        if not isinstance(self.window, Window):
-            raise ValueError("Window should be a Window object.")
-
-        if self.Cell is not None:
-            if len(self.Cell.data) != self.window.n_observations():
-                raise ValueError(
-                    "Data should have the same number of elements as the number of "
-                    "observations supported by the window function."
-                )
-
-        if not measurement_supports_harmonic(
-            self.XY.x_measurement
-        ) or not measurement_supports_harmonic(self.XY.y_measurement):
-            raise ValueError(
-                f"Measurements {self.XY.x_measurement} and "
-                f"{self.XY.y_measurement} must support harmonic-space calculations."
-            )
-
-    def __str__(self) -> str:
-        """Return a string representation of the TwoPointCWindow object."""
-        return f"{self.XY}[{self.get_sacc_name()}]"
-
-    def get_sacc_name(self) -> str:
-        """Return the SACC name for the two-point function."""
-        return type_to_sacc_string_harmonic(
-            self.XY.x_measurement, self.XY.y_measurement
-        )
-
-    def __eq__(self, other) -> bool:
-        """Equality test for TwoPointCWindow objects."""
-        return (
-            self.XY == other.XY
-            and self.window == other.window
-            and compare_optionals(self.Cell, other.Cell)
-        )
-
-    def has_data(self) -> bool:
-        """Return True if the TwoPointCWindow object has a Cell array."""
-        return self.Cell is not None
+    def n_observations(self) -> int:
+        """Return the number of observations described by these metadata.
+
+        :return: The number of observations.
+        """
+        if self.window is None:
+            return self.ells.shape[0]
+        return self.window.shape[1]
 
 
 @dataclass(frozen=True, kw_only=True)
-class TwoPointXiTheta(YAMLSerializable):
+class TwoPointReal(YAMLSerializable):
     """Class defining the metadata for a real-space two-point measurement.
 
     The class used to store the metadata for a real-space two-point function measured
@@ -710,7 +580,6 @@ class TwoPointXiTheta(YAMLSerializable):
 
     XY: TwoPointXY
     thetas: npt.NDArray[np.float64]
-    xis: None | TwoPointMeasurement = None
 
     def __post_init__(self):
-        """Validate the TwoPointCWindow data.
+        """Validate the TwoPointReal data.
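+
+        An illustrative example of a construction that passes this validation
+        (`xy` here is a hypothetical TwoPointXY whose measurements support
+        real-space calculations)::
+
+            TwoPointReal(XY=xy, thetas=np.linspace(0.1, 1.0, 20))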
@@ -720,9 +589,6 @@ def __post_init__(self):
         if len(self.thetas.shape) != 1:
             raise ValueError("Thetas should be a 1D array.")
 
-        if self.xis is not None and self.xis.data.shape != self.thetas.shape:
-            raise ValueError("Xis should have the same shape as thetas.")
-
         if not measurement_supports_real(
             self.XY.x_measurement
         ) or not measurement_supports_real(self.XY.y_measurement):
@@ -732,7 +598,7 @@ def __post_init__(self):
             )
 
     def __str__(self) -> str:
-        """Return a string representation of the TwoPointXiTheta object."""
+        """Return a string representation of the TwoPointReal object."""
        return f"{self.XY}[{self.get_sacc_name()}]"
 
     def get_sacc_name(self) -> str:
@@ -740,13 +606,15 @@ def get_sacc_name(self) -> str:
         return type_to_sacc_string_real(self.XY.x_measurement, self.XY.y_measurement)
 
     def __eq__(self, other) -> bool:
-        """Equality test for TwoPointXiTheta objects."""
-        return (
-            self.XY == other.XY
-            and np.array_equal(self.thetas, other.thetas)
-            and compare_optionals(self.xis, other.xis)
-        )
+        """Equality test for TwoPointReal objects."""
+        if not isinstance(other, TwoPointReal):
+            raise ValueError("Can only compare TwoPointReal objects.")
 
-    def has_data(self) -> bool:
-        """Return True if the TwoPointXiTheta object has a xis array."""
-        return self.xis is not None
+        return self.XY == other.XY and np.array_equal(self.thetas, other.thetas)
+
+    def n_observations(self) -> int:
+        """Return the number of observations described by these metadata.
+
+        :return: The number of observations.
+        """
+        return self.thetas.shape[0]
diff --git a/tests/conftest.py b/tests/conftest.py
index ee9813e7..df9789cc 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -18,19 +18,15 @@
 from firecrown.parameters import ParamsMap
 from firecrown.connector.mapping import MappingCosmoSIS, mapping_builder
 from firecrown.modeling_tools import ModelingTools
-from firecrown.metadata.two_point_types import (
+from firecrown.metadata_types import (
     Galaxies,
-    measurement_is_compatible_harmonic,
-    measurement_is_compatible_real,
-)
-from firecrown.metadata.two_point import (
-    TracerNames,
     InferredGalaxyZDist,
-    Window,
+    TracerNames,
+    TwoPointHarmonic,
     TwoPointXY,
-    TwoPointCells,
-    TwoPointCWindow,
-    TwoPointXiTheta,
+    TwoPointReal,
+    measurement_is_compatible_harmonic,
+    measurement_is_compatible_real,
 )
 import firecrown.likelihood.weak_lensing as wl
 import firecrown.likelihood.number_counts as nc
@@ -169,14 +165,18 @@ def fixture_tools_with_vanilla_cosmology():
 @pytest.fixture(
     name="harmonic_bin_1",
     params=[Galaxies.COUNTS, Galaxies.SHEAR_E],
+    ids=["counts", "shear_e"],
 )
 def make_harmonic_bin_1(request) -> InferredGalaxyZDist:
-    """Generate an InferredGalaxyZDist object with 5 bins."""
+    """Generate an InferredGalaxyZDist for bin_1 with a Gaussian dndz."""
+    z = np.linspace(0.0, 1.0, 256)  # Necessary to match the default lensing kernel size
+    z_mean = 0.5
+    z_sigma = 0.05
+    dndz = np.exp(-0.5 * (z - z_mean) ** 2 / z_sigma**2) / (
+        np.sqrt(2 * np.pi) * z_sigma
+    )
     x = InferredGalaxyZDist(
-        bin_name="bin_1",
-        z=np.linspace(0, 1, 5),
-        dndz=np.array([0.1, 0.5, 0.2, 0.3, 0.4]),
-        measurements={request.param},
+        bin_name="bin_1", z=z, dndz=dndz, measurements={request.param}
     )
     return x
 
 
@@ -184,14 +184,18 @@
 @pytest.fixture(
     name="harmonic_bin_2",
     params=[Galaxies.COUNTS, Galaxies.SHEAR_E],
+    ids=["counts", "shear_e"],
 )
 def make_harmonic_bin_2(request) -> InferredGalaxyZDist:
-    """Generate an InferredGalaxyZDist object with 3 bins."""
+    """Generate an InferredGalaxyZDist for bin_2 with a Gaussian dndz."""
+    z = np.linspace(0.0, 1.0, 256)  # Necessary to match the default lensing kernel size
+    z_mean = 0.6
+    z_sigma = 0.05
+    dndz = np.exp(-0.5 * (z - z_mean) ** 2 / z_sigma**2) / (
+        np.sqrt(2 * np.pi) * z_sigma
+    )
     x = InferredGalaxyZDist(
-        bin_name="bin_2",
-        z=np.linspace(0, 1, 3),
-        dndz=np.array([0.1, 0.5, 0.4]),
-        measurements={request.param},
+        bin_name="bin_2", z=z, dndz=dndz, measurements={request.param}
     )
     return x
 
 
@@ -204,6 +208,7 @@
         Galaxies.SHEAR_MINUS,
         Galaxies.SHEAR_PLUS,
     ],
+    ids=["counts", "shear_t", "shear_minus", "shear_plus"],
 )
 def make_real_bin_1(request) -> InferredGalaxyZDist:
     """Generate an InferredGalaxyZDist object with 5 bins."""
@@ -224,6 +229,7 @@ def make_real_bin_1(request) -> InferredGalaxyZDist:
         Galaxies.SHEAR_MINUS,
         Galaxies.SHEAR_PLUS,
     ],
+    ids=["counts", "shear_t", "shear_minus", "shear_plus"],
 )
 def make_real_bin_2(request) -> InferredGalaxyZDist:
     """Generate an InferredGalaxyZDist object with 3 bins."""
@@ -237,18 +243,12 @@ def make_real_bin_2(request) -> InferredGalaxyZDist:
 
 
 @pytest.fixture(name="window_1")
-def make_window_1() -> Window:
-    """Generate a Window object with 100 ells."""
+def make_window_1() -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
+    """Generate ells and window weights for 100 ells and 4 observations."""
     ells = np.array(np.linspace(0, 100, 100), dtype=np.int64)
-    ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64)
     weights = np.ones(400).reshape(-1, 4)
 
-    window = Window(
-        ells=ells,
-        weights=weights,
-        ells_for_interpolation=ells_for_interpolation,
-    )
-    return window
+    return ells, weights
 
 
 @pytest.fixture(name="harmonic_two_point_xy")
@@ -283,27 +283,30 @@ def make_real_two_point_xy(
 
 @pytest.fixture(name="two_point_cwindow")
 def make_two_point_cwindow(
-    window_1: Window, harmonic_two_point_xy: TwoPointXY
-) -> TwoPointCWindow:
-    """Generate a TwoPointCWindow object with 100 ells."""
-    two_point = TwoPointCWindow(XY=harmonic_two_point_xy, window=window_1)
+    window_1: tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]],
+    harmonic_two_point_xy: TwoPointXY,
+) -> TwoPointHarmonic:
+    """Generate a TwoPointHarmonic object with 100 ells and a window."""
+    two_point = TwoPointHarmonic(
+        XY=harmonic_two_point_xy, ells=window_1[0], window=window_1[1]
+    )
     return two_point
 
 
 @pytest.fixture(name="two_point_cell")
 def make_two_point_cell(
     harmonic_two_point_xy: TwoPointXY,
-) -> TwoPointCells:
-    """Generate a TwoPointCWindow object with 100 ells."""
+) -> TwoPointHarmonic:
+    """Generate a TwoPointHarmonic object with 100 ells."""
     ells = np.array(np.linspace(0, 100, 100), dtype=np.int64)
-    return TwoPointCells(ells=ells, XY=harmonic_two_point_xy)
+    return TwoPointHarmonic(ells=ells, XY=harmonic_two_point_xy)
 
 
-@pytest.fixture(name="two_point_xi_theta")
-def make_two_point_xi_theta(real_two_point_xy: TwoPointXY) -> TwoPointXiTheta:
-    """Generate a TwoPointCWindow object with 100 ells."""
+@pytest.fixture(name="two_point_real")
+def make_two_point_real(real_two_point_xy: TwoPointXY) -> TwoPointReal:
+    """Generate a TwoPointReal object with 100 thetas."""
     thetas = np.array(np.linspace(0, 100, 100), dtype=np.float64)
-    return TwoPointXiTheta(thetas=thetas, XY=real_two_point_xy)
+    return TwoPointReal(thetas=thetas, XY=real_two_point_xy)
 
 
 @pytest.fixture(name="cluster_sacc_data")
diff --git a/tests/generators/test_inferred_galaxy_zdist.py b/tests/generators/test_inferred_galaxy_zdist.py
index abbd5659..43cb3469 100644
--- a/tests/generators/test_inferred_galaxy_zdist.py
+++ b/tests/generators/test_inferred_galaxy_zdist.py
@@ -27,7 +27,7 @@
     make_measurements,
     make_measurements_dict,
 )
-from firecrown.metadata.two_point_types import Galaxies, Clusters, CMB
+from firecrown.metadata_types import Galaxies, Clusters, CMB
 from firecrown.utils import base_model_from_yaml, base_model_to_yaml
 
 
diff --git
a/tests/likelihood/gauss_family/statistic/source/test_source.py b/tests/likelihood/gauss_family/statistic/source/test_source.py index bf2146a9..d21f456e 100644 --- a/tests/likelihood/gauss_family/statistic/source/test_source.py +++ b/tests/likelihood/gauss_family/statistic/source/test_source.py @@ -19,7 +19,7 @@ ) import firecrown.likelihood.number_counts as nc import firecrown.likelihood.weak_lensing as wl -from firecrown.metadata.two_point import extract_all_tracers +from firecrown.metadata_functions import extract_all_tracers_inferred_galaxy_zdists from firecrown.parameters import ParamsMap @@ -178,7 +178,7 @@ def test_weak_lensing_source_init(sacc_galaxy_cells_src0_src0): def test_weak_lensing_source_create_ready(sacc_galaxy_cells_src0_src0): sacc_data, _, _ = sacc_galaxy_cells_src0_src0 - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) src0 = next((obj for obj in all_tracers if obj.bin_name == "src0"), None) assert src0 is not None @@ -194,7 +194,7 @@ def test_weak_lensing_source_create_ready(sacc_galaxy_cells_src0_src0): def test_weak_lensing_source_factory(sacc_galaxy_cells_src0_src0): sacc_data, _, _ = sacc_galaxy_cells_src0_src0 - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) src0 = next((obj for obj in all_tracers if obj.bin_name == "src0"), None) assert src0 is not None @@ -211,7 +211,7 @@ def test_weak_lensing_source_factory(sacc_galaxy_cells_src0_src0): def test_weak_lensing_source_factory_cache(sacc_galaxy_cells_src0_src0): sacc_data, _, _ = sacc_galaxy_cells_src0_src0 - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) src0 = next((obj for obj in all_tracers if obj.bin_name == "src0"), None) assert src0 is not None @@ -224,7 +224,7 @@ def test_weak_lensing_source_factory_cache(sacc_galaxy_cells_src0_src0): def test_weak_lensing_source_factory_global_systematics(sacc_galaxy_cells_src0_src0): sacc_data, _, _ = sacc_galaxy_cells_src0_src0 - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) src0 = next((obj for obj in all_tracers if obj.bin_name == "src0"), None) assert src0 is not None @@ -272,7 +272,7 @@ def test_number_counts_source_init(sacc_galaxy_cells_lens0_lens0): def test_number_counts_source_create_ready(sacc_galaxy_cells_lens0_lens0): sacc_data, _, _ = sacc_galaxy_cells_lens0_lens0 - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) lens0 = next((obj for obj in all_tracers if obj.bin_name == "lens0"), None) assert lens0 is not None @@ -288,7 +288,7 @@ def test_number_counts_source_create_ready(sacc_galaxy_cells_lens0_lens0): def test_number_counts_source_factory(sacc_galaxy_cells_lens0_lens0): sacc_data, _, _ = sacc_galaxy_cells_lens0_lens0 - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) lens0 = next((obj for obj in all_tracers if obj.bin_name == "lens0"), None) assert lens0 is not None @@ -305,7 +305,7 @@ def test_number_counts_source_factory(sacc_galaxy_cells_lens0_lens0): def test_number_counts_source_factory_cache(sacc_galaxy_cells_lens0_lens0): sacc_data, _, _ = sacc_galaxy_cells_lens0_lens0 - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) lens0 = next((obj for obj in all_tracers if obj.bin_name == 
"lens0"), None) assert lens0 is not None @@ -318,7 +318,7 @@ def test_number_counts_source_factory_cache(sacc_galaxy_cells_lens0_lens0): def test_number_counts_source_factory_global_systematics(sacc_galaxy_cells_lens0_lens0): sacc_data, _, _ = sacc_galaxy_cells_lens0_lens0 - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) lens0 = next((obj for obj in all_tracers if obj.bin_name == "lens0"), None) assert lens0 is not None diff --git a/tests/likelihood/gauss_family/statistic/test_two_point.py b/tests/likelihood/gauss_family/statistic/test_two_point.py index fb4aa524..f71fda8b 100644 --- a/tests/likelihood/gauss_family/statistic/test_two_point.py +++ b/tests/likelihood/gauss_family/statistic/test_two_point.py @@ -5,6 +5,7 @@ import re from unittest.mock import MagicMock import numpy as np +from numpy.testing import assert_allclose import pytest import pyccl @@ -18,6 +19,7 @@ from firecrown.likelihood.weak_lensing import ( WeakLensing, ) +from firecrown.likelihood.statistic import TheoryVector from firecrown.likelihood.two_point import ( _ell_for_xi, TwoPoint, @@ -25,20 +27,22 @@ TRACER_NAMES_TOTAL, EllOrThetaConfig, use_source_factory, - use_source_factory_metadata_only, + use_source_factory_metadata_index, WeakLensingFactory, NumberCountsFactory, ) -from firecrown.metadata.two_point_types import ( +from firecrown.metadata_types import ( Galaxies, InferredGalaxyZDist, + TwoPointHarmonic, GALAXY_LENS_TYPES, GALAXY_SOURCE_TYPES, ) -from firecrown.metadata.two_point import ( - TwoPointCellsIndex, - TwoPointXiThetaIndex, +from firecrown.metadata_functions import ( + TwoPointHarmonicIndex, + TwoPointRealIndex, ) +from firecrown.data_types import TwoPointMeasurement @pytest.fixture(name="source_0") @@ -53,6 +57,47 @@ def fixture_tools() -> ModelingTools: return ModelingTools() +@pytest.fixture(name="harmonic_data_with_window") +def fixture_harmonic_data_with_window(harmonic_two_point_xy) -> TwoPointMeasurement: + """Return some fake harmonic data.""" + ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) + # The window is given by the mean of the ells in each bin times the bin number. 
+ weights = np.zeros((100, 4)) + weights[0:25, 0] = 1.0 / 25.0 + weights[25:50, 1] = 2.0 / 25.0 + weights[50:75, 2] = 3.0 / 25.0 + weights[75:100, 3] = 4.0 / 25.0 + + data = np.zeros(4) + 1.1 + indices = np.arange(4) + covariance_name = "cov" + tpm = TwoPointMeasurement( + data=data, + indices=indices, + covariance_name=covariance_name, + metadata=TwoPointHarmonic(ells=ells, window=weights, XY=harmonic_two_point_xy), + ) + + return tpm + + +@pytest.fixture(name="harmonic_data_no_window") +def fixture_harmonic_data_no_window(harmonic_two_point_xy) -> TwoPointMeasurement: + """Return some fake harmonic data.""" + ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) + data = np.zeros(100) - 1.1 + indices = np.arange(100) + covariance_name = "cov" + tpm = TwoPointMeasurement( + data=data, + indices=indices, + covariance_name=covariance_name, + metadata=TwoPointHarmonic(ells=ells, XY=harmonic_two_point_xy), + ) + + return tpm + + def test_ell_for_xi_no_rounding(): res = _ell_for_xi(minimum=0, midpoint=5, maximum=80, n_log=5) expected = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 40.0, 80.0]) @@ -107,8 +152,6 @@ def test_two_point_src0_src0_window(sacc_galaxy_cells_src0_src0_window): tools.prepare(pyccl.CosmologyVanillaLCDM()) assert statistic.window is not None - assert statistic.window.ells_for_interpolation is not None - assert all(np.isfinite(statistic.window.ells_for_interpolation)) statistic.reset() statistic.update(ParamsMap()) @@ -428,7 +471,7 @@ def test_use_source_factory_invalid_measurement(harmonic_bin_1: InferredGalaxyZD def test_use_source_factory_metadata_only_counts(): wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) - source = use_source_factory_metadata_only( + source = use_source_factory_metadata_index( "bin1", Galaxies.COUNTS, wl_factory=wl_factory, nc_factory=nc_factory ) assert isinstance(source, NumberCounts) @@ -437,7 +480,7 @@ def test_use_source_factory_metadata_only_counts(): def test_use_source_factory_metadata_only_shear(): wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) - source = use_source_factory_metadata_only( + source = use_source_factory_metadata_index( "bin1", Galaxies.SHEAR_E, wl_factory=wl_factory, nc_factory=nc_factory ) assert isinstance(source, WeakLensing) @@ -447,7 +490,7 @@ def test_use_source_factory_metadata_only_invalid_measurement(): wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) with pytest.raises(ValueError, match="Unknown measurement type encountered .*"): - use_source_factory_metadata_only( + use_source_factory_metadata_index( "bin1", 120, wl_factory=wl_factory, nc_factory=nc_factory # type: ignore ) @@ -463,10 +506,7 @@ def test_from_metadata_harmonic_wrong_metadata(): with pytest.raises( ValueError, match=re.escape("Metadata of type is not supported") ): - TwoPoint._from_metadata( # pylint: disable=protected-access - sacc_data_type="galaxy_density_xi", - source0=NumberCounts(sacc_tracer="lens_0"), - source1=NumberCounts(sacc_tracer="lens_0"), + TwoPoint._from_metadata_single( # pylint: disable=protected-access metadata="NotAMetadata", # type: ignore ) @@ -476,7 +516,7 @@ def test_use_source_factory_metadata_only_wrong_measurement(): unknown_type.configure_mock(__eq__=MagicMock(return_value=False)) with 
pytest.raises(ValueError, match="Measurement .* not supported!"): - use_source_factory_metadata_only( + use_source_factory_metadata_index( "bin1", unknown_type, wl_factory=None, nc_factory=None ) @@ -484,12 +524,11 @@ def test_use_source_factory_metadata_only_wrong_measurement(): def test_from_metadata_only_harmonic(): wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) - metadata: TwoPointCellsIndex = { + metadata: TwoPointHarmonicIndex = { "data_type": "galaxy_density_xi", "tracer_names": TracerNames("lens0", "lens0"), - "ells": np.array(np.linspace(0, 100, 100), dtype=np.int64), } - two_point = TwoPoint.from_metadata_only_harmonic( + two_point = TwoPoint.from_metadata_index( [metadata], wl_factory=wl_factory, nc_factory=nc_factory, @@ -501,15 +540,97 @@ def test_from_metadata_only_harmonic(): def test_from_metadata_only_real(): wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) - metadata: TwoPointXiThetaIndex = { + metadata: TwoPointRealIndex = { "data_type": "galaxy_shear_xi_plus", "tracer_names": TracerNames("src0", "src0"), - "thetas": np.linspace(0.0, 1.0, 100), } - two_point = TwoPoint.from_metadata_only_real( + two_point = TwoPoint.from_metadata_index( [metadata], wl_factory=wl_factory, nc_factory=nc_factory, ).pop() assert isinstance(two_point, TwoPoint) assert not two_point.ready + + +def test_from_measurement_compute_theory_vector_window( + harmonic_data_with_window: TwoPointMeasurement, +): + wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) + nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) + two_points = TwoPoint.from_measurement( + [harmonic_data_with_window], wl_factory=wl_factory, nc_factory=nc_factory + ) + two_point: TwoPoint = two_points.pop() + + assert isinstance(two_point, TwoPoint) + assert two_point.ready + + req_params = two_point.required_parameters() + default_values = req_params.get_default_values() + params = ParamsMap(default_values) + + tools = ModelingTools() + tools.update(params) + tools.prepare(pyccl.CosmologyVanillaLCDM()) + two_point.update(params) + + prediction = two_point.compute_theory_vector(tools) + + assert isinstance(prediction, TheoryVector) + assert prediction.shape == (4,) + + +def test_from_measurement_compute_theory_vector_window_check( + harmonic_data_with_window: TwoPointMeasurement, + harmonic_data_no_window: TwoPointMeasurement, +): + wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) + nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) + + two_points_with_window = TwoPoint.from_measurement( + [harmonic_data_with_window], wl_factory=wl_factory, nc_factory=nc_factory + ) + two_point_with_window: TwoPoint = two_points_with_window.pop() + + two_points_without_window = TwoPoint.from_measurement( + [harmonic_data_no_window], wl_factory=wl_factory, nc_factory=nc_factory + ) + two_point_without_window: TwoPoint = two_points_without_window.pop() + + assert isinstance(two_point_with_window, TwoPoint) + assert two_point_with_window.ready + + assert isinstance(two_point_without_window, TwoPoint) + assert two_point_without_window.ready + + req_params = two_point_with_window.required_parameters() + default_values = req_params.get_default_values() + params = ParamsMap(default_values) + + tools = ModelingTools() + 
tools.update(params) + tools.prepare(pyccl.CosmologyVanillaLCDM()) + + two_point_with_window.update(params) + two_point_without_window.update(params) + + prediction_with_window = two_point_with_window.compute_theory_vector(tools) + prediction_without_window = two_point_without_window.compute_theory_vector(tools) + + assert isinstance(prediction_with_window, TheoryVector) + assert prediction_with_window.shape == (4,) + + assert isinstance(prediction_without_window, TheoryVector) + assert prediction_without_window.shape == (100,) + # Currently the C0 and C1 are set to 0 when a window is present, so we need to do + # the same here. + prediction_without_window[0:2] = 0.0 + + binned_after = [ + np.mean(prediction_without_window[0:25]) * 1.0, + np.mean(prediction_without_window[25:50]) * 2.0, + np.mean(prediction_without_window[50:75]) * 3.0, + np.mean(prediction_without_window[75:100]) * 4.0, + ] + assert_allclose(prediction_with_window, binned_after) diff --git a/tests/likelihood/gauss_family/test_const_gaussian.py b/tests/likelihood/gauss_family/test_const_gaussian.py index 6736e86d..a5d6e007 100644 --- a/tests/likelihood/gauss_family/test_const_gaussian.py +++ b/tests/likelihood/gauss_family/test_const_gaussian.py @@ -26,11 +26,13 @@ WeakLensingFactory, NumberCountsFactory, ) -from firecrown.metadata.two_point import ( - extract_all_data_cells, - TwoPointCellsIndex, +from firecrown.metadata_types import ( TracerNames, ) +from firecrown.metadata_functions import ( + TwoPointHarmonicIndex, +) +from firecrown.data_functions import extract_all_harmonic_data class StatisticWithoutIndices(TrivialStatistic): @@ -483,13 +485,11 @@ def test_access_required_parameters( def test_create_ready(sacc_galaxy_cwindows): sacc_data, _, _ = sacc_galaxy_cwindows - _, two_point_cwindows = extract_all_data_cells(sacc_data) + two_point_harmonics = extract_all_harmonic_data(sacc_data) wl_factory = WeakLensingFactory(global_systematics=[], per_bin_systematics=[]) nc_factory = NumberCountsFactory(global_systematics=[], per_bin_systematics=[]) - two_points = TwoPoint.from_metadata_harmonic( - two_point_cwindows, wl_factory, nc_factory - ) + two_points = TwoPoint.from_measurement(two_point_harmonics, wl_factory, nc_factory) size = np.sum([len(two_point.get_data_vector()) for two_point in two_points]) likelihood = ConstGaussian.create_ready(two_points, np.diag(np.ones(size))) @@ -499,13 +499,11 @@ def test_create_ready(sacc_galaxy_cwindows): def test_create_ready_wrong_size(sacc_galaxy_cwindows): sacc_data, _, _ = sacc_galaxy_cwindows - _, two_point_cwindows = extract_all_data_cells(sacc_data) + two_point_harmonics = extract_all_harmonic_data(sacc_data) wl_factory = WeakLensingFactory(global_systematics=[], per_bin_systematics=[]) nc_factory = NumberCountsFactory(global_systematics=[], per_bin_systematics=[]) - two_points = TwoPoint.from_metadata_harmonic( - two_point_cwindows, wl_factory, nc_factory - ) + two_points = TwoPoint.from_measurement(two_point_harmonics, wl_factory, nc_factory) size = np.sum([len(two_point.get_data_vector()) for two_point in two_points]) with pytest.raises( @@ -522,15 +520,12 @@ def test_create_ready_not_ready(): wl_factory = WeakLensingFactory(global_systematics=[], per_bin_systematics=[]) nc_factory = NumberCountsFactory(global_systematics=[], per_bin_systematics=[]) - metadata: TwoPointCellsIndex = { + metadata: TwoPointHarmonicIndex = { "data_type": "galaxy_density_xi", "tracer_names": TracerNames("lens0", "lens0"), - "ells": np.array(np.linspace(0, 100, 100), dtype=np.int64), } - 
two_points = TwoPoint.from_metadata_only_harmonic(
-        [metadata], wl_factory, nc_factory
-    )
+    two_points = TwoPoint.from_metadata_index([metadata], wl_factory, nc_factory)
 
     with pytest.raises(
         RuntimeError,
diff --git a/tests/metadata/sacc_name_mapping.py b/tests/metadata/sacc_name_mapping.py
index bf4059ec..eeb55e53 100644
--- a/tests/metadata/sacc_name_mapping.py
+++ b/tests/metadata/sacc_name_mapping.py
@@ -1,6 +1,6 @@
 """Module containing testing data for translation of SACC names."""
 
-from firecrown.metadata.two_point_types import (
+from firecrown.metadata_types import (
     Clusters,
     CMB,
     Galaxies,
diff --git a/tests/metadata/test_metadata_two_point.py b/tests/metadata/test_metadata_two_point.py
index d0357f0c..8c5cc092 100644
--- a/tests/metadata/test_metadata_two_point.py
+++ b/tests/metadata/test_metadata_two_point.py
@@ -1,30 +1,24 @@
 """
-Tests for the module firecrown.metadata.two_point
+Tests for the modules firecrown.metadata_types and firecrown.metadata_functions.
 """
 
 import pytest
 import numpy as np
 from numpy.testing import assert_array_equal
 
-from firecrown.metadata.two_point_types import (
+from firecrown.metadata_types import (
     ALL_MEASUREMENTS,
     type_to_sacc_string_harmonic as harmonic,
     type_to_sacc_string_real as real,
-)
-
-from firecrown.metadata.two_point_types import (
     Clusters,
     CMB,
     Galaxies,
     InferredGalaxyZDist,
     TracerNames,
-    TwoPointCells,
-    TwoPointCWindow,
+    TwoPointHarmonic,
     TwoPointXY,
-    TwoPointXiTheta,
-    TwoPointMeasurement,
-    Window,
+    TwoPointReal,
 )
-
+from firecrown.data_types import TwoPointMeasurement
 from firecrown.likelihood.source import SourceGalaxy
 from firecrown.likelihood.two_point import TwoPoint
 
@@ -186,7 +180,7 @@ def test_two_point_xy_invalid():
     )
 
 
-def test_two_point_cells():
+def test_two_point_harmonic():
     x = InferredGalaxyZDist(
         bin_name="bname1",
         z=np.linspace(0, 1, 100),
@@ -203,14 +197,15 @@ def test_two_point_cells():
         x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=Galaxies.COUNTS
     )
     ells = np.array(np.linspace(0, 100, 100), dtype=np.int64)
-    cells = TwoPointCells(ells=ells, XY=xy)
+    cells = TwoPointHarmonic(ells=ells, XY=xy)
 
     assert_array_equal(cells.ells, ells)
     assert cells.XY == xy
     assert cells.get_sacc_name() == harmonic(xy.x_measurement, xy.y_measurement)
+    assert cells.n_observations() == 100
 
 
-def test_two_point_cells_invalid_ells():
+def test_two_point_harmonic_invalid_ells():
     x = InferredGalaxyZDist(
         bin_name="bname1",
         z=np.linspace(0, 1, 100),
@@ -231,10 +226,10 @@ def test_two_point_cells_invalid_ells():
         ValueError,
         match="Ells should be a 1D array.",
     ):
-        TwoPointCells(ells=ells, XY=xy)
 
 
-def test_two_point_cells_invalid_type():
+        TwoPointHarmonic(ells=ells, XY=xy)
+
+
+def test_two_point_harmonic_invalid_type():
     x = InferredGalaxyZDist(
         bin_name="bname1",
         z=np.linspace(0, 1, 100),
@@ -255,206 +250,134 @@ def test_two_point_cells_invalid_type():
         ValueError,
         match="Measurements .* and .* must support harmonic-space calculations.",
     ):
-        TwoPointCells(ells=ells, XY=xy)
-
-
-def test_window_equality():
-    # Construct two windows with the same parameters, but different objects.
- window_1 = Window( - ells=np.array(np.linspace(0, 100, 100), dtype=np.int64), - weights=np.ones(400).reshape(-1, 4), - ells_for_interpolation=np.array(np.linspace(0, 100, 100), dtype=np.int64), - ) - window_2 = Window( - ells=np.array(np.linspace(0, 100, 100), dtype=np.int64), - weights=np.ones(400).reshape(-1, 4), - ells_for_interpolation=np.array(np.linspace(0, 100, 100), dtype=np.int64), - ) - # Two windows constructed from the same parameters should be equal. - assert window_1 == window_2 - - # If we get rid of the ells_for_interpolation from one, they should no - # longer be equal. - window_2.ells_for_interpolation = None - assert window_1 != window_2 - assert window_2 != window_1 - - # If we have nulled out both ells_for_interpolation, they should be equal - # again. - window_1.ells_for_interpolation = None - assert window_1 == window_2 + TwoPointHarmonic(ells=ells, XY=xy) - # And if we change the ells, they should no longer be equal. - window_2.ells[0] = window_2.ells[0] + 1 - assert window_1 != window_2 - -def test_two_point_window(): +def test_two_point_cwindow(harmonic_two_point_xy: TwoPointXY): ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64) - weights = np.ones(400).reshape(-1, 4) - - window = Window( - ells=ells, - weights=weights, - ells_for_interpolation=ells_for_interpolation, - ) - - assert_array_equal(window.ells, ells) - assert_array_equal(window.weights, weights) - assert window.ells_for_interpolation is not None - assert_array_equal(window.ells_for_interpolation, ells_for_interpolation) - assert window.n_observations() == 4 - - -def test_two_point_window_invalid_ells(): - ells = np.array(np.linspace(0, 100), dtype=np.int64).reshape(-1, 10) weights = np.ones(400).reshape(-1, 4) - ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64) - with pytest.raises( - ValueError, - match="Ells should be a 1D array.", - ): - Window( - ells=ells, - weights=weights, - ells_for_interpolation=ells_for_interpolation, - ) - - -def test_two_point_window_invalid_weights(): ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - weights = np.ones(400).reshape(-1, 5) - ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64) + two_point = TwoPointHarmonic(XY=harmonic_two_point_xy, ells=ells, window=weights) - with pytest.raises( - ValueError, - match="Weights should have the same number of rows as ells.", - ): - Window( - ells=ells, - weights=weights, - ells_for_interpolation=ells_for_interpolation, - ) + assert two_point.window is not None + assert_array_equal(two_point.window, weights) + assert two_point.XY == harmonic_two_point_xy + assert two_point.get_sacc_name() == harmonic( + harmonic_two_point_xy.x_measurement, harmonic_two_point_xy.y_measurement + ) + assert two_point.n_observations() == 4 -def test_two_point_window_invalid_ells_for_interpolation(): +def test_two_point_cwindow_wrong_data_shape( + harmonic_two_point_xy: TwoPointXY, +): ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) weights = np.ones(400).reshape(-1, 4) - ells_for_interpolation = np.array(np.linspace(0, 100), dtype=np.int64).reshape( - -1, 10 - ) + data = np.zeros(100) + 1.1 + indices = np.arange(100) + covariance_name = "cov" + data = data.reshape(-1, 10) with pytest.raises( ValueError, - match="Ells for interpolation should be a 1D array.", + match="Data should be a 1D array.", ): - Window( - ells=ells, - weights=weights, - ells_for_interpolation=ells_for_interpolation, + 
TwoPointMeasurement( + data=data, + indices=indices, + covariance_name=covariance_name, + metadata=TwoPointHarmonic( + XY=harmonic_two_point_xy, + ells=ells, + window=weights, + ), ) -def test_two_point_window_invalid_weights_shape(): - ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - weights = np.ones(400) - ells_for_interpolation = np.array(np.linspace(0, 100), dtype=np.int64) - +def test_two_point_measurement_invalid_metadata(): + data = np.zeros(100) + 1.1 + indices = np.arange(100) + covariance_name = "cov" with pytest.raises( ValueError, - match="Weights should be a 2D array.", + match="Metadata should be an instance of TwoPointReal or TwoPointHarmonic.", ): - Window( - ells=ells, - weights=weights, - ells_for_interpolation=ells_for_interpolation, + TwoPointMeasurement( + data=data, + indices=indices, + covariance_name=covariance_name, + metadata="Im not a metadata", # type: ignore ) -def test_two_point_two_point_cwindow(harmonic_two_point_xy: TwoPointXY): +def test_two_point_cwindow_stringify(harmonic_two_point_xy: TwoPointXY): ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64) weights = np.ones(400).reshape(-1, 4) - window = Window( - ells=ells, - weights=weights, - ells_for_interpolation=ells_for_interpolation, - ) - - ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - two_point = TwoPointCWindow(XY=harmonic_two_point_xy, window=window) + two_point = TwoPointHarmonic(XY=harmonic_two_point_xy, ells=ells, window=weights) - assert two_point.window == window - assert two_point.XY == harmonic_two_point_xy - assert two_point.get_sacc_name() == harmonic( - harmonic_two_point_xy.x_measurement, harmonic_two_point_xy.y_measurement + assert ( + str(two_point) == f"{str(harmonic_two_point_xy)}[{two_point.get_sacc_name()}]" ) -def test_two_point_two_point_cwindow_wrong_data_shape( - harmonic_two_point_xy: TwoPointXY, -): +def test_two_point_cwindow_invalid(): ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64) weights = np.ones(400).reshape(-1, 4) - window = Window( - ells=ells, - weights=weights, - ells_for_interpolation=ells_for_interpolation, + x = InferredGalaxyZDist( + bin_name="bname1", + z=np.linspace(0, 1, 100), + dndz=np.ones(100), + measurements={Galaxies.COUNTS}, ) - - data = np.zeros(100) + 1.1 - indices = np.arange(100) - covariance_name = "cov" - data = data.reshape(-1, 10) + y = InferredGalaxyZDist( + bin_name="bname2", + z=np.linspace(0, 1, 100), + dndz=np.ones(100), + measurements={Galaxies.SHEAR_T}, + ) + xy = TwoPointXY( + x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=Galaxies.SHEAR_T + ) + ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) with pytest.raises( ValueError, - match="Data should be a 1D array.", + match="Measurements .* and .* must support harmonic-space calculations.", ): - TwoPointCWindow( - XY=harmonic_two_point_xy, - window=window, - Cell=TwoPointMeasurement( - data=data, - indices=indices, - covariance_name=covariance_name, - ), - ) + TwoPointHarmonic(XY=xy, ells=ells, window=weights) -def test_two_point_two_point_cwindow_stringify(harmonic_two_point_xy: TwoPointXY): - ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64) - weights = np.ones(400).reshape(-1, 4) - - window = Window( - ells=ells, - weights=weights, - ells_for_interpolation=ells_for_interpolation, +def 
test_two_point_cwindow_invalid_window(): + x = InferredGalaxyZDist( + bin_name="bname1", + z=np.linspace(0, 1, 100), + dndz=np.ones(100), + measurements={Galaxies.COUNTS}, ) - - two_point = TwoPointCWindow(XY=harmonic_two_point_xy, window=window) - - assert ( - str(two_point) == f"{str(harmonic_two_point_xy)}[{two_point.get_sacc_name()}]" + y = InferredGalaxyZDist( + bin_name="bname2", + z=np.linspace(0, 1, 100), + dndz=np.ones(100), + measurements={Galaxies.SHEAR_T}, + ) + xy = TwoPointXY( + x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=Galaxies.SHEAR_T ) + with pytest.raises( + ValueError, + match="window should be a ndarray.", + ): + TwoPointHarmonic( + XY=xy, ells=np.array([1.0]), window="Im not a window" # type: ignore + ) -def test_two_point_two_point_cwindow_invalid(): +def test_two_point_cwindow_invalid_window_shape(): ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64) weights = np.ones(400).reshape(-1, 4) - window = Window( - ells=ells, - weights=weights, - ells_for_interpolation=ells_for_interpolation, - ) - x = InferredGalaxyZDist( bin_name="bname1", z=np.linspace(0, 1, 100), @@ -470,15 +393,18 @@ def test_two_point_two_point_cwindow_invalid(): xy = TwoPointXY( x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=Galaxies.SHEAR_T ) - ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) + weights = np.ones(400) with pytest.raises( ValueError, - match="Measurements .* and .* must support harmonic-space calculations.", + match="window should be a 2D array.", ): - TwoPointCWindow(XY=xy, window=window) + TwoPointHarmonic(XY=xy, ells=ells, window=weights) + +def test_two_point_cwindow_window_ell_not_match(): + ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) + weights = np.ones(400).reshape(-1, 4) -def test_two_point_two_point_cwindow_invalid_window(): x = InferredGalaxyZDist( bin_name="bname1", z=np.linspace(0, 1, 100), @@ -494,14 +420,15 @@ def test_two_point_two_point_cwindow_invalid_window(): xy = TwoPointXY( x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=Galaxies.SHEAR_T ) + weights = np.ones(400).reshape(-1, 4) with pytest.raises( ValueError, - match="Window should be a Window object.", + match="window should have the same number of rows as ells.", ): - TwoPointCWindow(XY=xy, window="Im not a window") # type: ignore + TwoPointHarmonic(XY=xy, ells=ells[:10], window=weights) -def test_two_point_xi_theta(): +def test_two_point_real(): x = InferredGalaxyZDist( bin_name="bname1", z=np.linspace(0, 1, 100), @@ -518,14 +445,15 @@ def test_two_point_xi_theta(): x=x, y=y, x_measurement=Galaxies.COUNTS, y_measurement=Galaxies.COUNTS ) theta = np.array(np.linspace(0, 100, 100)) - two_point = TwoPointXiTheta(XY=xy, thetas=theta) + two_point = TwoPointReal(XY=xy, thetas=theta) assert_array_equal(two_point.thetas, theta) assert two_point.XY == xy assert two_point.get_sacc_name() == real(xy.x_measurement, xy.y_measurement) + assert two_point.n_observations() == 100 -def test_two_point_xi_theta_invalid(): +def test_two_point_real_invalid(): x = InferredGalaxyZDist( bin_name="bname1", z=np.linspace(0, 1, 100), @@ -546,7 +474,7 @@ def test_two_point_xi_theta_invalid(): ValueError, match="Measurements .* and .* must support real-space calculations.", ): - TwoPointXiTheta(XY=xy, thetas=theta) + TwoPointReal(XY=xy, thetas=theta) def test_harmonic_type_string_invalid(): @@ -600,66 +528,82 @@ def test_two_point_xy_serialization(harmonic_two_point_xy: TwoPointXY): assert 
str(harmonic_two_point_xy) == str(recovered) -def test_two_point_cells_str(harmonic_two_point_xy: TwoPointXY): +def test_two_point_harmonic_str(harmonic_two_point_xy: TwoPointXY): ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - cells = TwoPointCells(ells=ells, XY=harmonic_two_point_xy) + cells = TwoPointHarmonic(ells=ells, XY=harmonic_two_point_xy) assert str(cells) == f"{str(harmonic_two_point_xy)}[{cells.get_sacc_name()}]" -def test_two_point_cells_serialization(harmonic_two_point_xy: TwoPointXY): +def test_two_point_harmonic_serialization(harmonic_two_point_xy: TwoPointXY): ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - cells = TwoPointCells(ells=ells, XY=harmonic_two_point_xy) + cells = TwoPointHarmonic(ells=ells, XY=harmonic_two_point_xy) s = cells.to_yaml() - recovered = TwoPointCells.from_yaml(s) + recovered = TwoPointHarmonic.from_yaml(s) assert cells == recovered assert str(harmonic_two_point_xy) == str(recovered.XY) assert str(cells) == str(recovered) -def test_two_point_cells_ells_wrong_shape(harmonic_two_point_xy: TwoPointXY): - ells = np.array(np.linspace(0, 100), dtype=np.int64).reshape(-1, 10) +def test_two_point_harmonic_cmp_invalid(harmonic_two_point_xy: TwoPointXY): + ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) + cells = TwoPointHarmonic(ells=ells, XY=harmonic_two_point_xy) with pytest.raises( ValueError, - match="Ells should be a 1D array.", + match="Can only compare TwoPointHarmonic objects.", ): - TwoPointCells(ells=ells, XY=harmonic_two_point_xy) + _ = cells == "Im not a TwoPointHarmonic" -def test_window_serialization(window_1: Window): - s = window_1.to_yaml() - recovered = Window.from_yaml(s) - assert window_1 == recovered +def test_two_point_harmonic_ells_wrong_shape(harmonic_two_point_xy: TwoPointXY): + ells = np.array(np.linspace(0, 100), dtype=np.int64).reshape(-1, 10) + with pytest.raises( + ValueError, + match="Ells should be a 1D array.", + ): + TwoPointHarmonic(ells=ells, XY=harmonic_two_point_xy) -def test_two_point_cwindow_serialization(two_point_cwindow: TwoPointCWindow): +def test_two_point_harmonic_with_window_serialization( + two_point_cwindow: TwoPointHarmonic, +): s = two_point_cwindow.to_yaml() - recovered = TwoPointCWindow.from_yaml(s) + recovered = TwoPointHarmonic.from_yaml(s) assert two_point_cwindow == recovered -def test_two_point_xi_theta_serialization(real_two_point_xy: TwoPointXY): +def test_two_point_real_serialization(real_two_point_xy: TwoPointXY): theta = np.array(np.linspace(0, 10, 10)) - xi_theta = TwoPointXiTheta(XY=real_two_point_xy, thetas=theta) + xi_theta = TwoPointReal(XY=real_two_point_xy, thetas=theta) s = xi_theta.to_yaml() - recovered = TwoPointXiTheta.from_yaml(s) + recovered = TwoPointReal.from_yaml(s) assert xi_theta == recovered assert str(real_two_point_xy) == str(recovered.XY) assert str(xi_theta) == str(recovered) -def test_two_point_xi_theta_wrong_shape(real_two_point_xy: TwoPointXY): +def test_two_point_real_cmp_invalid(real_two_point_xy: TwoPointXY): + theta = np.array(np.linspace(0, 10, 10)) + xi_theta = TwoPointReal(XY=real_two_point_xy, thetas=theta) + with pytest.raises( + ValueError, + match="Can only compare TwoPointReal objects.", + ): + _ = xi_theta == "Im not a TwoPointReal" + + +def test_two_point_real_wrong_shape(real_two_point_xy: TwoPointXY): theta = np.array(np.linspace(0, 10), dtype=np.float64).reshape(-1, 10) with pytest.raises( ValueError, match="Thetas should be a 1D array.", ): - TwoPointXiTheta(XY=real_two_point_xy, thetas=theta) + 
TwoPointReal(XY=real_two_point_xy, thetas=theta) def test_two_point_from_metadata_cells(harmonic_two_point_xy, wl_factory, nc_factory): ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - cells = TwoPointCells(ells=ells, XY=harmonic_two_point_xy) - two_point = TwoPoint.from_metadata_harmonic([cells], wl_factory, nc_factory).pop() + cells = TwoPointHarmonic(ells=ells, XY=harmonic_two_point_xy) + two_point = TwoPoint.from_metadata([cells], wl_factory, nc_factory).pop() assert two_point is not None assert isinstance(two_point, TwoPoint) @@ -676,7 +620,7 @@ def test_two_point_from_metadata_cells(harmonic_two_point_xy, wl_factory, nc_fac def test_two_point_from_metadata_cwindow(two_point_cwindow, wl_factory, nc_factory): - two_point = TwoPoint.from_metadata_harmonic( + two_point = TwoPoint.from_metadata( [two_point_cwindow], wl_factory, nc_factory ).pop() @@ -696,10 +640,10 @@ def test_two_point_from_metadata_cwindow(two_point_cwindow, wl_factory, nc_facto def test_two_point_from_metadata_xi_theta(real_two_point_xy, wl_factory, nc_factory): theta = np.array(np.linspace(0, 100, 100)) - xi_theta = TwoPointXiTheta(XY=real_two_point_xy, thetas=theta) + xi_theta = TwoPointReal(XY=real_two_point_xy, thetas=theta) if xi_theta.get_sacc_name() == "galaxy_shear_xi_tt": return - two_point = TwoPoint.from_metadata_real([xi_theta], wl_factory, nc_factory).pop() + two_point = TwoPoint.from_metadata([xi_theta], wl_factory, nc_factory).pop() assert two_point is not None assert isinstance(two_point, TwoPoint) @@ -732,9 +676,9 @@ def test_two_point_from_metadata_cells_unsupported_type(wl_factory, nc_factory): xy = TwoPointXY( x=x, y=y, x_measurement=CMB.CONVERGENCE, y_measurement=Galaxies.COUNTS ) - cells = TwoPointCells(ells=ells, XY=xy) + cells = TwoPointHarmonic(ells=ells, XY=xy) with pytest.raises( ValueError, match="Measurement .* not supported!", ): - TwoPoint.from_metadata_harmonic([cells], wl_factory, nc_factory) + TwoPoint.from_metadata([cells], wl_factory, nc_factory) diff --git a/tests/metadata/test_metadata_two_point_measurement.py b/tests/metadata/test_metadata_two_point_measurement.py index b64b8d1f..c655df51 100644 --- a/tests/metadata/test_metadata_two_point_measurement.py +++ b/tests/metadata/test_metadata_two_point_measurement.py @@ -1,23 +1,19 @@ """ -Tests for the module firecrown.metadata.two_point +Tests for the module firecrown.metadata_types and firecrown.metadata_functions. 
""" import pytest import numpy as np from numpy.testing import assert_array_equal -from firecrown.metadata.two_point_types import ( +from firecrown.metadata_types import ( + TwoPointHarmonic, + TwoPointXY, + TwoPointReal, type_to_sacc_string_harmonic as harmonic, type_to_sacc_string_real as real, ) -from firecrown.metadata.two_point import ( - TwoPointCells, - TwoPointCWindow, - TwoPointXiTheta, - TwoPointXY, - TwoPointMeasurement, - Window, -) +from firecrown.data_types import TwoPointMeasurement def test_two_point_cells_with_data(harmonic_two_point_xy: TwoPointXY): @@ -25,74 +21,73 @@ def test_two_point_cells_with_data(harmonic_two_point_xy: TwoPointXY): data = np.ones_like(ells) * 1.1 indices = np.arange(100) covariance_name = "cov" - measure = TwoPointMeasurement( - data=data, indices=indices, covariance_name=covariance_name + tpm = TwoPointMeasurement( + data=data, + indices=indices, + covariance_name=covariance_name, + metadata=TwoPointHarmonic(ells=ells, XY=harmonic_two_point_xy), ) + metadata = tpm.metadata + assert isinstance(metadata, TwoPointHarmonic) - cells = TwoPointCells(ells=ells, XY=harmonic_two_point_xy, Cell=measure) - - assert cells.ells[0] == 0 - assert cells.ells[-1] == 100 - assert cells.Cell is not None - assert cells.has_data() - assert_array_equal(cells.Cell.data, data) - assert_array_equal(cells.Cell.indices, indices) - assert cells.Cell.covariance_name == covariance_name + assert metadata.ells[0] == 0 + assert metadata.ells[-1] == 100 + assert tpm.data is not None + assert_array_equal(tpm.data, data) + assert_array_equal(tpm.indices, indices) + assert tpm.covariance_name == covariance_name def test_two_point_two_point_cwindow_with_data(harmonic_two_point_xy: TwoPointXY): ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64) weights = np.ones(400).reshape(-1, 4) - window = Window( - ells=ells, - weights=weights, - ells_for_interpolation=ells_for_interpolation, - ) - ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) data = np.zeros(4) + 1.1 indices = np.arange(4) covariance_name = "cov" - measure = TwoPointMeasurement( - data=data, indices=indices, covariance_name=covariance_name + tpm = TwoPointMeasurement( + data=data, + indices=indices, + covariance_name=covariance_name, + metadata=TwoPointHarmonic(ells=ells, window=weights, XY=harmonic_two_point_xy), ) + metadata = tpm.metadata + assert isinstance(metadata, TwoPointHarmonic) - two_point = TwoPointCWindow(XY=harmonic_two_point_xy, window=window, Cell=measure) - - assert two_point.window == window - assert two_point.XY == harmonic_two_point_xy - assert two_point.get_sacc_name() == harmonic( + assert metadata.window is not None + assert_array_equal(metadata.window, weights) + assert metadata.XY == harmonic_two_point_xy + assert metadata.get_sacc_name() == harmonic( harmonic_two_point_xy.x_measurement, harmonic_two_point_xy.y_measurement ) - assert two_point.has_data() - assert two_point.Cell is not None - assert_array_equal(two_point.Cell.data, data) - assert_array_equal(two_point.Cell.indices, indices) - assert two_point.Cell.covariance_name == covariance_name + assert tpm.data is not None + assert_array_equal(tpm.data, data) + assert_array_equal(tpm.indices, indices) + assert tpm.covariance_name == covariance_name def test_two_point_xi_theta_with_data(real_two_point_xy: TwoPointXY): + thetas = np.linspace(0.0, 1.0, 100) data = np.zeros(100) + 1.1 indices = np.arange(100) covariance_name = "cov" - measure = 
TwoPointMeasurement( - data=data, indices=indices, covariance_name=covariance_name + tpm = TwoPointMeasurement( + data=data, + indices=indices, + covariance_name=covariance_name, + metadata=TwoPointReal(XY=real_two_point_xy, thetas=thetas), ) - thetas = np.linspace(0.0, 1.0, 100) + metadata = tpm.metadata + assert isinstance(metadata, TwoPointReal) - xi_theta = TwoPointXiTheta(XY=real_two_point_xy, thetas=thetas, xis=measure) - - assert xi_theta.XY == real_two_point_xy - assert xi_theta.get_sacc_name() == real( + assert metadata.XY == real_two_point_xy + assert metadata.get_sacc_name() == real( real_two_point_xy.x_measurement, real_two_point_xy.y_measurement ) - assert xi_theta.has_data() - assert xi_theta.xis is not None - assert_array_equal(xi_theta.xis.data, data) - assert_array_equal(xi_theta.xis.indices, indices) - assert xi_theta.xis.covariance_name == covariance_name + assert_array_equal(tpm.data, data) + assert_array_equal(tpm.indices, indices) + assert tpm.covariance_name == covariance_name def test_two_point_cells_with_invalid_data_size(harmonic_two_point_xy: TwoPointXY): @@ -100,62 +95,62 @@ def test_two_point_cells_with_invalid_data_size(harmonic_two_point_xy: TwoPointX data = np.zeros(101) + 1.1 indices = np.arange(101) covariance_name = "cov" - measure = TwoPointMeasurement( - data=data, indices=indices, covariance_name=covariance_name - ) with pytest.raises( ValueError, - match="Cell should have the same shape as ells.", + match="Data and metadata should have the same length.", ): - TwoPointCells(ells=ells, XY=harmonic_two_point_xy, Cell=measure) + TwoPointMeasurement( + data=data, + indices=indices, + covariance_name=covariance_name, + metadata=TwoPointHarmonic(ells=ells, XY=harmonic_two_point_xy), + ) def test_two_point_cwindow_with_invalid_data_size(harmonic_two_point_xy: TwoPointXY): ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - ells_for_interpolation = np.array(np.linspace(0, 100, 100), dtype=np.int64) weights = np.ones(400).reshape(-1, 4) - window = Window( - ells=ells, - weights=weights, - ells_for_interpolation=ells_for_interpolation, - ) ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) data = np.zeros(5) + 1.1 indices = np.arange(5) covariance_name = "cov" - measure = TwoPointMeasurement( - data=data, indices=indices, covariance_name=covariance_name - ) with pytest.raises( ValueError, - match=( - "Data should have the same number of elements as the number " - "of observations supported by the window function." 
- ), + match="Data and metadata should have the same length.", ): - TwoPointCWindow(XY=harmonic_two_point_xy, window=window, Cell=measure) + TwoPointMeasurement( + data=data, + indices=indices, + covariance_name=covariance_name, + metadata=TwoPointHarmonic( + XY=harmonic_two_point_xy, ells=ells, window=weights + ), + ) def test_two_point_xi_theta_with_invalid_data_size(real_two_point_xy: TwoPointXY): + thetas = np.linspace(0.0, 1.0, 100) data = np.zeros(101) + 1.1 indices = np.arange(101) covariance_name = "cov" - measure = TwoPointMeasurement( - data=data, indices=indices, covariance_name=covariance_name - ) - thetas = np.linspace(0.0, 1.0, 100) with pytest.raises( ValueError, - match="Xis should have the same shape as thetas.", + match="Data and metadata should have the same length.", ): - TwoPointXiTheta(XY=real_two_point_xy, thetas=thetas, xis=measure) + TwoPointMeasurement( + data=data, + indices=indices, + covariance_name=covariance_name, + metadata=TwoPointReal(XY=real_two_point_xy, thetas=thetas), + ) -def test_two_point_measurement_invalid_data(): +def test_two_point_measurement_invalid_data(real_two_point_xy: TwoPointXY): + thetas = np.linspace(0.0, 1.0, 5) data = np.array([1, 2, 3, 4, 5]).reshape(-1, 1) indices = np.array([1, 2, 3, 4, 5]) covariance_name = "cov" @@ -163,10 +158,16 @@ def test_two_point_measurement_invalid_data(): ValueError, match="Data should be a 1D array.", ): - TwoPointMeasurement(data=data, indices=indices, covariance_name=covariance_name) + TwoPointMeasurement( + data=data, + indices=indices, + covariance_name=covariance_name, + metadata=TwoPointReal(XY=real_two_point_xy, thetas=thetas), + ) -def test_two_point_measurement_invalid_indices(): +def test_two_point_measurement_invalid_indices(real_two_point_xy: TwoPointXY): + thetas = np.linspace(0.0, 1.0, 5) data = np.array([1, 2, 3, 4, 5]) indices = np.array([1, 2, 3, 4, 5]).reshape(-1, 1) covariance_name = "cov" @@ -174,23 +175,36 @@ def test_two_point_measurement_invalid_indices(): ValueError, match="Data and indices should have the same shape.", ): - TwoPointMeasurement(data=data, indices=indices, covariance_name=covariance_name) + TwoPointMeasurement( + data=data, + indices=indices, + covariance_name=covariance_name, + metadata=TwoPointReal(XY=real_two_point_xy, thetas=thetas), + ) -def test_two_point_measurement_eq(): +def test_two_point_measurement_eq(real_two_point_xy: TwoPointXY): + thetas = np.linspace(0.0, 1.0, 5) data = np.array([1, 2, 3, 4, 5]) indices = np.array([1, 2, 3, 4, 5]) covariance_name = "cov" measure_1 = TwoPointMeasurement( - data=data, indices=indices, covariance_name=covariance_name + data=data, + indices=indices, + covariance_name=covariance_name, + metadata=TwoPointReal(XY=real_two_point_xy, thetas=thetas), ) measure_2 = TwoPointMeasurement( - data=data, indices=indices, covariance_name=covariance_name + data=data, + indices=indices, + covariance_name=covariance_name, + metadata=TwoPointReal(XY=real_two_point_xy, thetas=thetas), ) assert measure_1 == measure_2 -def test_two_point_measurement_neq(): +def test_two_point_measurement_neq(real_two_point_xy: TwoPointXY): + thetas = np.linspace(0.0, 1.0, 5) data = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) indices = np.array([1, 2, 3, 4, 5]) covariance_name = "cov" @@ -198,20 +212,38 @@ def test_two_point_measurement_neq(): data=data, indices=indices, covariance_name=covariance_name, + metadata=TwoPointReal(XY=real_two_point_xy, thetas=thetas), + ) + measure_2 = TwoPointMeasurement( + data=data, + indices=indices, + covariance_name="cov2", + 
metadata=TwoPointReal(XY=real_two_point_xy, thetas=thetas), ) - measure_2 = TwoPointMeasurement(data=data, indices=indices, covariance_name="cov2") assert measure_1 != measure_2 measure_3 = TwoPointMeasurement( - data=data, indices=indices, covariance_name=covariance_name + data=data, + indices=indices, + covariance_name=covariance_name, + metadata=TwoPointReal(XY=real_two_point_xy, thetas=thetas), ) measure_4 = TwoPointMeasurement( - data=data, indices=indices + 1, covariance_name=covariance_name + data=data, + indices=indices + 1, + covariance_name=covariance_name, + metadata=TwoPointReal(XY=real_two_point_xy, thetas=thetas), ) assert measure_3 != measure_4 measure_5 = TwoPointMeasurement( - data=data, indices=indices, covariance_name=covariance_name + data=data, + indices=indices, + covariance_name=covariance_name, + metadata=TwoPointReal(XY=real_two_point_xy, thetas=thetas), ) measure_6 = TwoPointMeasurement( - data=data + 1.0, indices=indices, covariance_name=covariance_name + data=data + 1.0, + indices=indices, + covariance_name=covariance_name, + metadata=TwoPointReal(XY=real_two_point_xy, thetas=thetas), ) assert measure_5 != measure_6 diff --git a/tests/metadata/test_metadata_two_point_sacc.py b/tests/metadata/test_metadata_two_point_sacc.py index b220f4bc..6761ca14 100644 --- a/tests/metadata/test_metadata_two_point_sacc.py +++ b/tests/metadata/test_metadata_two_point_sacc.py @@ -1,4 +1,4 @@ -"""Tests for the module firecrown.metadata.two_point. +"""Tests for the modules firecrown.metadata_types and firecrown.metadata_functions. In this module, we test the functions and classes involved in SACC extraction tools. """ @@ -11,24 +11,28 @@ import sacc from firecrown.parameters import ParamsMap -from firecrown.metadata.two_point_types import ( +from firecrown.metadata_types import ( Galaxies, + TracerNames, + TwoPointHarmonic, + TwoPointReal, type_to_sacc_string_harmonic, type_to_sacc_string_real, ) -from firecrown.metadata.two_point import ( - extract_all_data_types_cells, - extract_all_data_types_xi_thetas, +from firecrown.metadata_functions import ( + extract_all_harmonic_metadata_indices, + extract_all_harmonic_metadata, extract_all_photoz_bin_combinations, - extract_all_tracers, + extract_all_real_metadata_indices, + extract_all_real_metadata, + extract_all_tracers_inferred_galaxy_zdists, extract_window_function, - extract_all_data_cells, - extract_all_data_xi_thetas, +) +from firecrown.data_functions import ( check_two_point_consistence_harmonic, check_two_point_consistence_real, - TracerNames, - TwoPointCells, - TwoPointXiTheta, + extract_all_harmonic_data, + extract_all_real_data, ) from firecrown.likelihood.two_point import TwoPoint, use_source_factory @@ -213,7 +217,7 @@ def fixture_sacc_galaxy_cells_three_tracers(): def test_extract_all_tracers_cells(sacc_galaxy_cells): sacc_data, tracers, _ = sacc_galaxy_cells assert sacc_data is not None - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) for tracer in all_tracers: orig_tracer = tracers[tracer.bin_name] @@ -224,7 +228,7 @@ def test_extract_all_tracers_cells(sacc_galaxy_cells): def test_extract_all_tracers_xis(sacc_galaxy_xis): sacc_data, tracers, _ = sacc_galaxy_xis assert sacc_data is not None - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) for tracer in all_tracers: orig_tracer = tracers[tracer.bin_name] @@ -235,7 +239,7 @@ def
test_extract_all_tracers_cells_src0_src0(sacc_galaxy_cells_src0_src0): sacc_data, z, dndz = sacc_galaxy_cells_src0_src0 assert sacc_data is not None - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) assert len(all_tracers) == 1 @@ -249,7 +253,7 @@ def test_extract_all_tracers_cells_src0_src0(sacc_galaxy_cells_src0_src0): def test_extract_all_tracers_cells_src0_src1(sacc_galaxy_cells_src0_src1): sacc_data, z, dndz0, dndz1 = sacc_galaxy_cells_src0_src1 assert sacc_data is not None - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) assert len(all_tracers) == 2 @@ -265,7 +269,7 @@ def test_extract_all_tracers_cells_src0_src1(sacc_galaxy_cells_src0_src1): def test_extract_all_tracers_cells_lens0_lens0(sacc_galaxy_cells_lens0_lens0): sacc_data, z, dndz = sacc_galaxy_cells_lens0_lens0 assert sacc_data is not None - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) assert len(all_tracers) == 1 @@ -279,7 +283,7 @@ def test_extract_all_tracers_cells_lens0_lens0(sacc_galaxy_cells_lens0_lens0): def test_extract_all_tracers_cells_lens0_lens1(sacc_galaxy_cells_lens0_lens1): sacc_data, z, dndz0, dndz1 = sacc_galaxy_cells_lens0_lens1 assert sacc_data is not None - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) assert len(all_tracers) == 2 @@ -295,7 +299,7 @@ def test_extract_all_tracers_cells_lens0_lens1(sacc_galaxy_cells_lens0_lens1): def test_extract_all_tracers_xis_lens0_lens0(sacc_galaxy_xis_lens0_lens0): sacc_data, z, dndz = sacc_galaxy_xis_lens0_lens0 assert sacc_data is not None - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) assert len(all_tracers) == 1 @@ -309,7 +313,7 @@ def test_extract_all_tracers_xis_lens0_lens0(sacc_galaxy_xis_lens0_lens0): def test_extract_all_tracers_xis_lens0_lens1(sacc_galaxy_xis_lens0_lens1): sacc_data, z, dndz0, dndz1 = sacc_galaxy_xis_lens0_lens1 assert sacc_data is not None - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) assert len(all_tracers) == 2 @@ -325,7 +329,7 @@ def test_extract_all_tracers_xis_lens0_lens1(sacc_galaxy_xis_lens0_lens1): def test_extract_all_trace_cells_src0_lens0(sacc_galaxy_cells_src0_lens0): sacc_data, z, dndz0, dndz1 = sacc_galaxy_cells_src0_lens0 assert sacc_data is not None - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) assert len(all_tracers) == 2 @@ -343,7 +347,7 @@ def test_extract_all_trace_cells_src0_lens0(sacc_galaxy_cells_src0_lens0): def test_extract_all_trace_xis_src0_lens0(sacc_galaxy_xis_src0_lens0): sacc_data, z, dndz0, dndz1 = sacc_galaxy_xis_src0_lens0 assert sacc_data is not None - all_tracers = extract_all_tracers(sacc_data) + all_tracers = extract_all_tracers_inferred_galaxy_zdists(sacc_data) assert len(all_tracers) == 2 @@ -366,7 +370,7 @@ def test_extract_all_tracers_invalid_data_type( with pytest.raises( ValueError, match="Tracer src0 does not have data points associated with it." 
): - _ = extract_all_tracers(sacc_data) + _ = extract_all_tracers_inferred_galaxy_zdists(sacc_data) def test_extract_all_tracers_bad_lens_label( @@ -378,7 +382,7 @@ def test_extract_all_tracers_bad_lens_label( ValueError, match="Tracer src0 does not have data points associated with it.", ): - _ = extract_all_tracers(sacc_data) + _ = extract_all_tracers_inferred_galaxy_zdists(sacc_data) def test_extract_all_tracers_bad_source_label( @@ -392,7 +396,7 @@ def test_extract_all_tracers_bad_source_label( "Tracer non_informative_label does not have data points associated with it." ), ): - _ = extract_all_tracers(sacc_data) + _ = extract_all_tracers_inferred_galaxy_zdists(sacc_data) def test_extract_all_tracers_inconsistent_lens_label( @@ -404,7 +408,7 @@ def test_extract_all_tracers_inconsistent_lens_label( ValueError, match=("Invalid SACC file, tracer names do not respect the naming convetion."), ): - _ = extract_all_tracers(sacc_data) + _ = extract_all_tracers_inferred_galaxy_zdists(sacc_data) def test_extract_all_tracers_inconsistent_source_label( @@ -416,27 +420,24 @@ def test_extract_all_tracers_inconsistent_source_label( ValueError, match=("Invalid SACC file, tracer names do not respect the naming convetion."), ): - _ = extract_all_tracers(sacc_data) + _ = extract_all_tracers_inferred_galaxy_zdists(sacc_data) -def test_extract_all_data_cells(sacc_galaxy_cells): +def test_extract_all_metadata_index_harmonics(sacc_galaxy_cells): sacc_data, _, tracer_pairs = sacc_galaxy_cells - all_data = extract_all_data_types_cells(sacc_data) + all_data = extract_all_harmonic_metadata_indices(sacc_data) assert len(all_data) == len(tracer_pairs) for two_point in all_data: tracer_names = two_point["tracer_names"] assert (tracer_names, two_point["data_type"]) in tracer_pairs - tracer_pair = tracer_pairs[(tracer_names, two_point["data_type"])] - assert_array_equal(two_point["ells"], tracer_pair[0]) - -def test_extract_all_data_cells_by_type(sacc_galaxy_cells): +def test_extract_all_metadata_index_harmonics_by_type(sacc_galaxy_cells): sacc_data, _, tracer_pairs = sacc_galaxy_cells - all_data = extract_all_data_types_cells( + all_data = extract_all_harmonic_metadata_indices( sacc_data, allowed_data_type=["galaxy_shear_cl_ee"] ) assert len(all_data) < len(tracer_pairs) @@ -445,28 +446,22 @@ def test_extract_all_data_cells_by_type(sacc_galaxy_cells): tracer_names = two_point["tracer_names"] assert (tracer_names, "galaxy_shear_cl_ee") in tracer_pairs - tracer_pair = tracer_pairs[(tracer_names, "galaxy_shear_cl_ee")] - assert_array_equal(two_point["ells"], tracer_pair[0]) - -def test_extract_all_data_xis(sacc_galaxy_xis): +def test_extract_all_metadata_index_reals(sacc_galaxy_xis): sacc_data, _, tracer_pairs = sacc_galaxy_xis - all_data = extract_all_data_types_xi_thetas(sacc_data) + all_data = extract_all_real_metadata_indices(sacc_data) assert len(all_data) == len(tracer_pairs) for two_point in all_data: tracer_names = two_point["tracer_names"] assert (tracer_names, two_point["data_type"]) in tracer_pairs - tracer_pair = tracer_pairs[(tracer_names, two_point["data_type"])] - assert_array_equal(two_point["thetas"], tracer_pair[0]) - -def test_extract_all_data_xis_by_type(sacc_galaxy_xis): +def test_extract_all_metadata_index_reals_by_type(sacc_galaxy_xis): sacc_data, _, tracer_pairs = sacc_galaxy_xis - all_data = extract_all_data_types_xi_thetas( + all_data = extract_all_real_metadata_indices( sacc_data, allowed_data_type=["galaxy_density_xi"] ) assert len(all_data) < len(tracer_pairs) @@ -475,19 +470,18 @@ def 
test_extract_all_data_xis_by_type(sacc_galaxy_xis): tracer_names = two_point["tracer_names"] assert (tracer_names, two_point["data_type"]) in tracer_pairs - tracer_pair = tracer_pairs[(tracer_names, two_point["data_type"])] - assert_array_equal(two_point["thetas"], tracer_pair[0]) - def test_extract_no_window(sacc_galaxy_cells_src0_src0): sacc_data, _, _ = sacc_galaxy_cells_src0_src0 indices = np.array([0, 1, 2], dtype=np.int64) with pytest.warns(UserWarning): - assert extract_window_function(sacc_data, indices=indices) is None + window = extract_window_function(sacc_data, indices=indices) + assert window[0] is None + assert window[1] is None -def test_extract_all_data_types_three_tracers_xis(sacc_galaxy_xis_three_tracers): +def test_extract_all_data_types_three_tracers_reals(sacc_galaxy_xis_three_tracers): sacc_data, _, _, _, _ = sacc_galaxy_xis_three_tracers with pytest.raises( @@ -497,10 +491,12 @@ def test_extract_all_data_types_three_tracers_xis(sacc_galaxy_xis_three_tracers) "does not have exactly two tracers." ), ): - _ = extract_all_data_types_xi_thetas(sacc_data) + _ = extract_all_real_metadata_indices(sacc_data) -def test_extract_all_data_types_three_tracers_cells(sacc_galaxy_cells_three_tracers): +def test_extract_all_data_types_three_tracers_harmonics( + sacc_galaxy_cells_three_tracers, +): sacc_data, _, _, _, _ = sacc_galaxy_cells_three_tracers with pytest.raises( @@ -510,10 +506,10 @@ def test_extract_all_data_types_three_tracers_cells(sacc_galaxy_cells_three_trac "does not have exactly two tracers." ), ): - _ = extract_all_data_types_cells(sacc_data) + _ = extract_all_harmonic_metadata_indices(sacc_data) -def test_extract_all_photoz_bin_combinations_xis(sacc_galaxy_xis): +def test_extract_all_photoz_bin_combinations_reals(sacc_galaxy_xis): sacc_data, _, tracer_pairs = sacc_galaxy_xis # We build all possible combinations of tracers all_bin_combs = extract_all_photoz_bin_combinations(sacc_data) @@ -532,13 +528,13 @@ def test_extract_all_photoz_bin_combinations_xis(sacc_galaxy_xis): assert tracer_names_type in tracer_names_list bin_comb = all_bin_combs[tracer_names_list.index(tracer_names_type)] - two_point_xis = TwoPointXiTheta( + two_point_xis = TwoPointReal( XY=bin_comb, thetas=np.linspace(0.0, 2.0 * np.pi, 20) ) assert two_point_xis.get_sacc_name() == tracer_names_type[1] -def test_extract_all_photoz_bin_combinations_cells(sacc_galaxy_cells): +def test_extract_all_photoz_bin_combinations_harmonics(sacc_galaxy_cells): sacc_data, _, tracer_pairs = sacc_galaxy_cells # We build all possible combinations of tracers all_bin_combs = extract_all_photoz_bin_combinations(sacc_data) @@ -559,20 +555,20 @@ def test_extract_all_photoz_bin_combinations_cells(sacc_galaxy_cells): assert tracer_names_type in tracer_names_list bin_comb = all_bin_combs[tracer_names_list.index(tracer_names_type)] - two_point_cells = TwoPointCells( + two_point_cells = TwoPointHarmonic( XY=bin_comb, ells=np.unique(np.logspace(1, 3, 10)) ) assert two_point_cells.get_sacc_name() == tracer_names_type[1] -def test_make_cells(sacc_galaxy_cells): +def test_make_harmonics(sacc_galaxy_cells): sacc_data, _, _ = sacc_galaxy_cells ells = np.unique(np.logspace(1, 3, 10).astype(np.int64)) all_bin_combs = extract_all_photoz_bin_combinations(sacc_data) for bin_comb in all_bin_combs: - two_point_cells = TwoPointCells(XY=bin_comb, ells=ells) + two_point_cells = TwoPointHarmonic(XY=bin_comb, ells=ells) assert two_point_cells.ells is not None assert_array_equal(two_point_cells.ells, ells) @@ -580,14 +576,14 @@ def 
test_make_cells(sacc_galaxy_cells): assert two_point_cells.XY == bin_comb -def test_make_xis(sacc_galaxy_xis): +def test_make_reals(sacc_galaxy_xis): sacc_data, _, _ = sacc_galaxy_xis thetas = np.linspace(0.0, 2.0 * np.pi, 20) all_bin_combs = extract_all_photoz_bin_combinations(sacc_data) for bin_comb in all_bin_combs: - two_point_xis = TwoPointXiTheta(XY=bin_comb, thetas=thetas) + two_point_xis = TwoPointReal(XY=bin_comb, thetas=thetas) assert two_point_xis.thetas is not None assert_array_equal(two_point_xis.thetas, thetas) @@ -595,76 +591,91 @@ def test_make_xis(sacc_galaxy_xis): assert two_point_xis.XY == bin_comb -def test_extract_all_two_point_cells(sacc_galaxy_cells): +def test_extract_all_harmonic_data(sacc_galaxy_cells): sacc_data, _, tracer_pairs = sacc_galaxy_cells - two_point_cells, two_point_cwindows = extract_all_data_cells(sacc_data) - assert len(two_point_cells) + len(two_point_cwindows) == len(tracer_pairs) - - assert len(two_point_cwindows) == 0 + two_point_harmonics = extract_all_harmonic_data(sacc_data) + assert len(two_point_harmonics) == len(tracer_pairs) - for two_point in two_point_cells: - tracer_names = TracerNames(two_point.XY.x.bin_name, two_point.XY.y.bin_name) - assert (tracer_names, two_point.get_sacc_name()) in tracer_pairs + for two_point in two_point_harmonics: + metadata = two_point.metadata + assert isinstance(metadata, TwoPointHarmonic) + tracer_names = TracerNames(metadata.XY.x.bin_name, metadata.XY.y.bin_name) + assert (tracer_names, metadata.get_sacc_name()) in tracer_pairs - ells, Cell = tracer_pairs[(tracer_names, two_point.get_sacc_name())] - assert_array_equal(two_point.ells, ells) - assert two_point.Cell is not None - assert_array_equal(two_point.Cell.data, Cell) + ells, Cell = tracer_pairs[(tracer_names, metadata.get_sacc_name())] + assert_array_equal(metadata.ells, ells) + assert_array_equal(two_point.data, Cell) - check_two_point_consistence_harmonic(two_point_cells) + check_two_point_consistence_harmonic(two_point_harmonics) -def test_extract_all_two_point_cwindows(sacc_galaxy_cwindows): +def test_extract_all_harmonic_with_window_data(sacc_galaxy_cwindows): sacc_data, _, tracer_pairs = sacc_galaxy_cwindows - two_point_cells, two_point_cwindows = extract_all_data_cells(sacc_data) - assert len(two_point_cells) + len(two_point_cwindows) == len(tracer_pairs) - assert len(two_point_cells) == 0 + two_point_harmonics = extract_all_harmonic_data(sacc_data) + assert len(two_point_harmonics) == len(tracer_pairs) - for two_point in two_point_cwindows: - tracer_names = TracerNames(two_point.XY.x.bin_name, two_point.XY.y.bin_name) - assert (tracer_names, two_point.get_sacc_name()) in tracer_pairs + for two_point in two_point_harmonics: + metadata = two_point.metadata + assert isinstance(metadata, TwoPointHarmonic) + tracer_names = TracerNames(metadata.XY.x.bin_name, metadata.XY.y.bin_name) + assert (tracer_names, metadata.get_sacc_name()) in tracer_pairs - ells, Cell, window = tracer_pairs[(tracer_names, two_point.get_sacc_name())] - assert_array_equal(two_point.window.ells, ells) - assert two_point.Cell is not None - assert_array_equal(two_point.Cell.data, Cell) - assert two_point.window is not None + ells, Cell, window = tracer_pairs[(tracer_names, metadata.get_sacc_name())] + assert_array_equal(metadata.ells, ells) + assert_array_equal(two_point.data, Cell) + assert metadata.window is not None - assert_array_equal( - two_point.window.weights, window.weight / window.weight.sum(axis=0) - ) - assert_array_equal(two_point.window.ells, window.values) + 
assert_array_equal(metadata.window, window.weight / window.weight.sum(axis=0)) + assert_array_equal(metadata.ells, window.values) - check_two_point_consistence_harmonic(two_point_cwindows) + check_two_point_consistence_harmonic(two_point_harmonics) -def test_extract_all_data_xi_thetas(sacc_galaxy_xis): +def test_extract_all_real_data(sacc_galaxy_xis): sacc_data, _, tracer_pairs = sacc_galaxy_xis - two_point_xis = extract_all_data_xi_thetas(sacc_data) + two_point_xis = extract_all_real_data(sacc_data) assert len(two_point_xis) == len(tracer_pairs) for two_point in two_point_xis: - tracer_names = TracerNames(two_point.XY.x.bin_name, two_point.XY.y.bin_name) - assert (tracer_names, two_point.get_sacc_name()) in tracer_pairs + metadata = two_point.metadata + assert isinstance(metadata, TwoPointReal) + tracer_names = TracerNames(metadata.XY.x.bin_name, metadata.XY.y.bin_name) + assert (tracer_names, metadata.get_sacc_name()) in tracer_pairs - thetas, xi = tracer_pairs[(tracer_names, two_point.get_sacc_name())] - assert_array_equal(two_point.thetas, thetas) - assert two_point.xis is not None - assert_array_equal(two_point.xis.data, xi) + thetas, xi = tracer_pairs[(tracer_names, metadata.get_sacc_name())] + assert_array_equal(metadata.thetas, thetas) + assert_array_equal(two_point.data, xi) check_two_point_consistence_real(two_point_xis) -def test_constructor_cells(sacc_galaxy_cells, wl_factory, nc_factory): +def test_constructor_harmonic_metadata(sacc_galaxy_cells, wl_factory, nc_factory): sacc_data, _, tracer_pairs = sacc_galaxy_cells - two_point_cells, _ = extract_all_data_cells(sacc_data) + two_point_harmonics = extract_all_harmonic_metadata(sacc_data) + + two_points_new = TwoPoint.from_metadata(two_point_harmonics, wl_factory, nc_factory) - two_points_new = TwoPoint.from_metadata_harmonic( - two_point_cells, wl_factory, nc_factory, check_consistence=True + assert two_points_new is not None + + for two_point in two_points_new: + tracer_pairs_key = (two_point.sacc_tracers, two_point.sacc_data_type) + assert tracer_pairs_key in tracer_pairs + assert two_point.ells is not None + assert_array_equal(two_point.ells, tracer_pairs[tracer_pairs_key][0]) + + +def test_constructor_harmonic_data(sacc_galaxy_cells, wl_factory, nc_factory): + sacc_data, _, tracer_pairs = sacc_galaxy_cells + + two_point_harmonics = extract_all_harmonic_data(sacc_data) + + check_two_point_consistence_harmonic(two_point_harmonics) + two_points_new = TwoPoint.from_measurement( + two_point_harmonics, wl_factory, nc_factory ) assert two_points_new is not None @@ -679,13 +690,41 @@ def test_constructor_cells(sacc_galaxy_cells, wl_factory, nc_factory): ) -def test_constructor_cwindows(sacc_galaxy_cwindows, wl_factory, nc_factory): +def test_constructor_harmonic_with_window_metadata( + sacc_galaxy_cwindows, wl_factory, nc_factory +): + sacc_data, _, tracer_pairs = sacc_galaxy_cwindows + + two_point_harmonics = extract_all_harmonic_metadata(sacc_data) + + two_points_new = TwoPoint.from_metadata(two_point_harmonics, wl_factory, nc_factory) + + assert two_points_new is not None + + for two_point in two_points_new: + tracer_pairs_key = (two_point.sacc_tracers, two_point.sacc_data_type) + _, _, window = tracer_pairs[tracer_pairs_key] + + assert tracer_pairs_key in tracer_pairs + assert two_point.ells is not None + assert two_point.window is not None + assert_array_equal( + two_point.window, + window.weight / window.weight.sum(axis=0), + ) + assert_array_equal(two_point.ells, window.values) + + +def 
test_constructor_harmonic_with_window_data( + sacc_galaxy_cwindows, wl_factory, nc_factory +): sacc_data, _, tracer_pairs = sacc_galaxy_cwindows - _, two_point_cwindows = extract_all_data_cells(sacc_data) + two_point_harmonics = extract_all_harmonic_data(sacc_data) - two_points_new = TwoPoint.from_metadata_harmonic( - two_point_cwindows, wl_factory, nc_factory + check_two_point_consistence_harmonic(two_point_harmonics) + two_points_new = TwoPoint.from_measurement( + two_point_harmonics, wl_factory, nc_factory ) assert two_points_new is not None @@ -695,24 +734,39 @@ def test_constructor_cwindows(sacc_galaxy_cwindows, wl_factory, nc_factory): _, Cell, window = tracer_pairs[tracer_pairs_key] assert tracer_pairs_key in tracer_pairs - assert two_point.ells is None + assert two_point.ells is not None assert_array_equal(two_point.get_data_vector(), Cell) assert two_point.window is not None assert_array_equal( - two_point.window.weights, + two_point.window, window.weight / window.weight.sum(axis=0), ) - assert_array_equal(two_point.window.ells, window.values) + assert_array_equal(two_point.ells, window.values) -def test_constructor_xis(sacc_galaxy_xis, wl_factory, nc_factory): +def test_constructor_reals_metadata(sacc_galaxy_xis, wl_factory, nc_factory): sacc_data, _, tracer_pairs = sacc_galaxy_xis - two_point_xis = extract_all_data_xi_thetas(sacc_data) + two_point_reals = extract_all_real_metadata(sacc_data) - two_points_new = TwoPoint.from_metadata_real( - two_point_xis, wl_factory, nc_factory, check_consistence=True - ) + two_points_new = TwoPoint.from_metadata(two_point_reals, wl_factory, nc_factory) + + assert two_points_new is not None + + for two_point in two_points_new: + tracer_pairs_key = (two_point.sacc_tracers, two_point.sacc_data_type) + assert tracer_pairs_key in tracer_pairs + assert two_point.thetas is not None + assert_array_equal(two_point.thetas, tracer_pairs[tracer_pairs_key][0]) + + +def test_constructor_reals_data(sacc_galaxy_xis, wl_factory, nc_factory): + sacc_data, _, tracer_pairs = sacc_galaxy_xis + + two_point_reals = extract_all_real_data(sacc_data) + + check_two_point_consistence_real(two_point_reals) + two_points_new = TwoPoint.from_measurement(two_point_reals, wl_factory, nc_factory) assert two_points_new is not None @@ -726,31 +780,31 @@ def test_constructor_xis(sacc_galaxy_xis, wl_factory, nc_factory): ) -def test_compare_constructors_cells( +def test_compare_constructors_harmonic( sacc_galaxy_cells, wl_factory, nc_factory, tools_with_vanilla_cosmology ): sacc_data, _, _ = sacc_galaxy_cells - two_point_cells, _ = extract_all_data_cells(sacc_data) + two_point_harmonics = extract_all_harmonic_data(sacc_data) - two_points_harmonic = TwoPoint.from_metadata_harmonic( - two_point_cells, wl_factory, nc_factory + two_points_harmonic = TwoPoint.from_measurement( + two_point_harmonics, wl_factory, nc_factory ) params = ParamsMap(two_points_harmonic.required_parameters().get_default_values()) two_points_old = [] - for cell in two_point_cells: - sacc_data_type = cell.get_sacc_name() + for tpm in two_point_harmonics: + sacc_data_type = tpm.metadata.get_sacc_name() source0 = use_source_factory( - cell.XY.x, - cell.XY.x_measurement, + tpm.metadata.XY.x, + tpm.metadata.XY.x_measurement, wl_factory=wl_factory, nc_factory=nc_factory, ) source1 = use_source_factory( - cell.XY.y, - cell.XY.y_measurement, + tpm.metadata.XY.y, + tpm.metadata.XY.y_measurement, wl_factory=wl_factory, nc_factory=nc_factory, ) @@ -780,31 +834,31 @@ def test_compare_constructors_cells( ) -def 
test_compare_constructors_cwindows( +def test_compare_constructors_harmonic_with_window( sacc_galaxy_cwindows, wl_factory, nc_factory, tools_with_vanilla_cosmology ): sacc_data, _, _ = sacc_galaxy_cwindows - _, two_point_cwindows = extract_all_data_cells(sacc_data) + two_point_harmonics = extract_all_harmonic_data(sacc_data) - two_points_harmonic = TwoPoint.from_metadata_harmonic( - two_point_cwindows, wl_factory, nc_factory + two_points_harmonic = TwoPoint.from_measurement( + two_point_harmonics, wl_factory, nc_factory ) params = ParamsMap(two_points_harmonic.required_parameters().get_default_values()) two_points_old = [] - for cwindow in two_point_cwindows: - sacc_data_type = cwindow.get_sacc_name() + for tpm in two_point_harmonics: + sacc_data_type = tpm.metadata.get_sacc_name() source0 = use_source_factory( - cwindow.XY.x, - cwindow.XY.x_measurement, + tpm.metadata.XY.x, + tpm.metadata.XY.x_measurement, wl_factory=wl_factory, nc_factory=nc_factory, ) source1 = use_source_factory( - cwindow.XY.y, - cwindow.XY.y_measurement, + tpm.metadata.XY.y, + tpm.metadata.XY.y_measurement, wl_factory=wl_factory, nc_factory=nc_factory, ) @@ -817,15 +871,16 @@ def test_compare_constructors_cwindows( for two_point, two_point_old in zip(two_points_harmonic, two_points_old): assert two_point.sacc_tracers == two_point_old.sacc_tracers assert two_point.sacc_data_type == two_point_old.sacc_data_type - assert two_point.ells is None assert ( two_point_old.ells is not None ) # The old constructor always sets the ells assert_array_equal(two_point.get_data_vector(), two_point_old.get_data_vector()) assert two_point.window is not None assert two_point_old.window is not None - assert_array_equal(two_point.window.weights, two_point_old.window.weights) - assert_array_equal(two_point.window.ells, two_point_old.window.ells) + assert_array_equal(two_point.window, two_point_old.window) + assert two_point.ells is not None + assert two_point_old.ells is not None + assert_array_equal(two_point.ells, two_point_old.ells) assert two_point.sacc_indices is not None assert two_point_old.sacc_indices is not None assert_array_equal(two_point.sacc_indices, two_point_old.sacc_indices) @@ -839,25 +894,31 @@ def test_compare_constructors_cwindows( ) -def test_compare_constructors_xis( +def test_compare_constructors_reals( sacc_galaxy_xis, wl_factory, nc_factory, tools_with_vanilla_cosmology ): sacc_data, _, _ = sacc_galaxy_xis - two_point_xis = extract_all_data_xi_thetas(sacc_data) + two_point_xis = extract_all_real_data(sacc_data) - two_points_real = TwoPoint.from_metadata_real(two_point_xis, wl_factory, nc_factory) + two_points_real = TwoPoint.from_measurement(two_point_xis, wl_factory, nc_factory) params = ParamsMap(two_points_real.required_parameters().get_default_values()) two_points_old = [] - for xi in two_point_xis: - sacc_data_type = xi.get_sacc_name() + for tpm in two_point_xis: + sacc_data_type = tpm.metadata.get_sacc_name() source0 = use_source_factory( - xi.XY.x, xi.XY.x_measurement, wl_factory=wl_factory, nc_factory=nc_factory + tpm.metadata.XY.x, + tpm.metadata.XY.x_measurement, + wl_factory=wl_factory, + nc_factory=nc_factory, ) source1 = use_source_factory( - xi.XY.y, xi.XY.y_measurement, wl_factory=wl_factory, nc_factory=nc_factory + tpm.metadata.XY.y, + tpm.metadata.XY.y_measurement, + wl_factory=wl_factory, + nc_factory=nc_factory, ) two_point = TwoPoint(sacc_data_type, source0, source1) two_point.read(sacc_data) @@ -885,11 +946,11 @@ def test_compare_constructors_xis( ) -def 
test_extract_all_data_cells_no_cov(sacc_galaxy_cells): +def test_extract_all_data_harmonic_no_cov(sacc_galaxy_cells): sacc_data, _, _ = sacc_galaxy_cells sacc_data.covariance = None with pytest.raises( ValueError, match=("The SACC object does not have a covariance matrix."), ): - _ = extract_all_data_cells(sacc_data, include_maybe_types=True) + _ = extract_all_harmonic_data(sacc_data, include_maybe_types=True) diff --git a/tests/metadata/test_metadata_two_point_types.py b/tests/metadata/test_metadata_two_point_types.py index ffce7814..7152469c 100644 --- a/tests/metadata/test_metadata_two_point_types.py +++ b/tests/metadata/test_metadata_two_point_types.py @@ -1,5 +1,5 @@ """ -Tests for the module firecrown.metadata.two_point +Tests for the modules firecrown.metadata_types and firecrown.metadata_functions. """ from dataclasses import replace @@ -12,34 +12,39 @@ import sacc import sacc_name_mapping as snm -from firecrown.metadata.two_point_types import ( - Galaxies, - Clusters, +from firecrown.metadata_types import ( + ALL_MEASUREMENTS, CMB, + Clusters, + Galaxies, + LENS_REGEX, + SOURCE_REGEX, TracerNames, - TwoPointMeasurement, - ALL_MEASUREMENTS, + TwoPointHarmonic, + TwoPointReal, compare_enums, - type_to_sacc_string_harmonic as harmonic, - type_to_sacc_string_real as real, measurement_is_compatible as is_compatible, - measurement_is_compatible_real as is_compatible_real, measurement_is_compatible_harmonic as is_compatible_harmonic, + measurement_is_compatible_real as is_compatible_real, measurement_supports_harmonic as supports_harmonic, measurement_supports_real as supports_real, + type_to_sacc_string_harmonic as harmonic, + type_to_sacc_string_real as real, ) -from firecrown.metadata.two_point import ( - extract_all_tracers_types, - measurements_from_index, - LENS_REGEX, - SOURCE_REGEX, - TwoPointXiThetaIndex, +from firecrown.metadata_functions import ( + TwoPointRealIndex, + extract_all_measured_types, match_name_type, + measurements_from_index, +) + +from firecrown.data_types import TwoPointMeasurement +from firecrown.data_functions import ( + extract_all_real_data, + extract_all_harmonic_data, check_two_point_consistence_harmonic, check_two_point_consistence_real, - extract_all_data_cells, - extract_all_data_xi_thetas, ) @@ -169,7 +174,7 @@ def test_extract_all_tracers_types_cells( ): sacc_data, _, _ = sacc_galaxy_cells - tracers = extract_all_tracers_types(sacc_data) + tracers = extract_all_measured_types(sacc_data) for tracer, measurements in tracers.items(): if LENS_REGEX.match(tracer): @@ -185,7 +190,7 @@ def test_extract_all_tracers_types_cwindows( ): sacc_data, _, _ = sacc_galaxy_cwindows - tracers = extract_all_tracers_types(sacc_data) + tracers = extract_all_measured_types(sacc_data) for tracer, measurements in tracers.items(): if LENS_REGEX.match(tracer): @@ -196,12 +201,10 @@ def test_extract_all_tracers_types_cwindows( assert measurement == Galaxies.SHEAR_E -def test_extract_all_tracers_types_xi_thetas( - sacc_galaxy_xis: tuple[sacc.Sacc, dict, dict] -): +def test_extract_all_tracers_types_reals(sacc_galaxy_xis: tuple[sacc.Sacc, dict, dict]): sacc_data, _, _ = sacc_galaxy_xis - tracers = extract_all_tracers_types(sacc_data) + tracers = extract_all_measured_types(sacc_data) for tracer, measurements in tracers.items(): if LENS_REGEX.match(tracer): @@ -215,12 +218,12 @@ def test_extract_all_tracers_types_xi_thetas( } -def test_extract_all_tracers_types_xi_thetas_inverted( +def test_extract_all_tracers_types_reals_inverted( sacc_galaxy_xis_inverted: tuple[sacc.Sacc, dict,
dict] ): sacc_data, _, _ = sacc_galaxy_xis_inverted - tracers = extract_all_tracers_types(sacc_data) + tracers = extract_all_measured_types(sacc_data) for tracer, measurements in tracers.items(): if LENS_REGEX.match(tracer): @@ -239,9 +242,9 @@ def test_extract_all_tracers_types_cells_include_maybe( ): sacc_data, _, _ = sacc_galaxy_cells - assert extract_all_tracers_types( + assert extract_all_measured_types( sacc_data, include_maybe_types=True - ) == extract_all_tracers_types(sacc_data) + ) == extract_all_measured_types(sacc_data) def test_extract_all_tracers_types_cwindows_include_maybe( @@ -249,9 +252,9 @@ def test_extract_all_tracers_types_cwindows_include_maybe( ): sacc_data, _, _ = sacc_galaxy_cwindows - assert extract_all_tracers_types( + assert extract_all_measured_types( sacc_data, include_maybe_types=True - ) == extract_all_tracers_types(sacc_data) + ) == extract_all_measured_types(sacc_data) def test_extract_all_tracers_types_xi_thetas_include_maybe( @@ -259,16 +262,15 @@ def test_extract_all_tracers_types_xi_thetas_include_maybe( ): sacc_data, _, _ = sacc_galaxy_xis - assert extract_all_tracers_types( + assert extract_all_measured_types( sacc_data, include_maybe_types=True - ) == extract_all_tracers_types(sacc_data) + ) == extract_all_measured_types(sacc_data) def test_measurements_from_index1(): - index: TwoPointXiThetaIndex = { + index: TwoPointRealIndex = { "data_type": "galaxy_shearDensity_xi_t", "tracer_names": TracerNames("src0", "lens0"), - "thetas": np.linspace(0.0, 1.0, 100), } n1, a, n2, b = measurements_from_index(index) assert n1 == "lens0" @@ -278,10 +280,9 @@ def test_measurements_from_index1(): def test_measurements_from_index2(): - index: TwoPointXiThetaIndex = { + index: TwoPointRealIndex = { "data_type": "galaxy_shearDensity_xi_t", "tracer_names": TracerNames("lens0", "src0"), - "thetas": np.linspace(0.0, 1.0, 100), } n1, a, n2, b = measurements_from_index(index) assert n1 == "lens0" @@ -419,156 +420,163 @@ def test_match_name_type_require_convention_source(): assert b == Galaxies.SHEAR_MINUS -def test_check_two_point_consistence_harmonic(two_point_cell): - Cell = TwoPointMeasurement( - data=np.zeros(100), indices=np.arange(100), covariance_name="cov" +def test_check_two_point_consistence_harmonic(two_point_cell: TwoPointHarmonic): + tpm = TwoPointMeasurement( + data=np.zeros(100), + indices=np.arange(100), + covariance_name="cov", + metadata=two_point_cell, ) - check_two_point_consistence_harmonic([replace(two_point_cell, Cell=Cell)]) + check_two_point_consistence_harmonic([tpm]) -def test_check_two_point_consistence_harmonic_missing_cell(two_point_cell): +def test_check_two_point_consistence_harmonic_real(two_point_real: TwoPointReal): + tpm = TwoPointMeasurement( + data=np.zeros(100), + indices=np.arange(100), + covariance_name="cov", + metadata=two_point_real, + ) with pytest.raises( ValueError, - match="The TwoPointCells \\(.*, .*\\)\\[.*\\] does not contain a data.", + match=(".*is not a measurement of TwoPointHarmonic."), ): - check_two_point_consistence_harmonic([two_point_cell]) + check_two_point_consistence_harmonic([tpm]) -def test_check_two_point_consistence_real(two_point_xi_theta): - xis = TwoPointMeasurement( - data=np.zeros(100), indices=np.arange(100), covariance_name="cov" +def test_check_two_point_consistence_real(two_point_real: TwoPointReal): + tpm = TwoPointMeasurement( - data=np.zeros(100), + data=np.zeros(100), + indices=np.arange(100), + covariance_name="cov", + metadata=two_point_real, ) - check_two_point_consistence_real([replace(two_point_xi_theta,
xis=xis)]) + check_two_point_consistence_real([tpm]) -def test_check_two_point_consistence_real_missing_xis(two_point_xi_theta): +def test_check_two_point_consistence_real_harmonic(two_point_cell: TwoPointHarmonic): + tpm = TwoPointMeasurement( + data=np.zeros(100), + indices=np.arange(100), + covariance_name="cov", + metadata=two_point_cell, + ) with pytest.raises( ValueError, - match="The TwoPointXiTheta \\(.*, .*\\)\\[.*\\] does not contain a data.", + match=(".*is not a measurement of TwoPointReal."), ): - check_two_point_consistence_real([two_point_xi_theta]) + check_two_point_consistence_real([tpm]) -def test_check_two_point_consistence_harmonic_mixing_cov(sacc_galaxy_cells): +def test_check_two_point_consistence_harmonic_mixing_cov( + sacc_galaxy_cells: tuple[sacc.Sacc, dict, dict] +): sacc_data, _, _ = sacc_galaxy_cells - two_point_cells, _ = extract_all_data_cells(sacc_data) + tpms = extract_all_harmonic_data(sacc_data) - assert two_point_cells[0].Cell is not None - two_point_cells[0] = replace( - two_point_cells[0], - Cell=replace(two_point_cells[0].Cell, covariance_name="wrong_cov_name"), - ) + tpms[0] = replace(tpms[0], covariance_name="wrong_cov_name") with pytest.raises( ValueError, match=( - "The TwoPointCells .* has a different covariance name .* " - "than the previous TwoPointCells wrong_cov_name." + ".* has a different covariance name .* " + "than the previous TwoPointHarmonic wrong_cov_name." ), ): - check_two_point_consistence_harmonic(two_point_cells) + check_two_point_consistence_harmonic(tpms) -def test_check_two_point_consistence_real_mixing_cov(sacc_galaxy_xis): +def test_check_two_point_consistence_real_mixing_cov( + sacc_galaxy_xis: tuple[sacc.Sacc, dict, dict] +): sacc_data, _, _ = sacc_galaxy_xis - two_point_xis = extract_all_data_xi_thetas(sacc_data) - assert two_point_xis[0].xis is not None - two_point_xis[0] = replace( - two_point_xis[0], - xis=replace(two_point_xis[0].xis, covariance_name="wrong_cov_name"), - ) + tpms = extract_all_real_data(sacc_data) + tpms[0] = replace(tpms[0], covariance_name="wrong_cov_name") with pytest.raises( ValueError, match=( - "The TwoPointXiTheta .* has a different covariance name .* than the " - "previous TwoPointXiTheta wrong_cov_name." + ".* has a different covariance name .* than the " + "previous TwoPointReal wrong_cov_name."
), ): - check_two_point_consistence_real(two_point_xis) + check_two_point_consistence_real(tpms) -def test_check_two_point_consistence_harmonic_non_unique_indices(sacc_galaxy_cells): +def test_check_two_point_consistence_harmonic_non_unique_indices( + sacc_galaxy_cells: tuple[sacc.Sacc, dict, dict] +): sacc_data, _, _ = sacc_galaxy_cells - two_point_cells, _ = extract_all_data_cells(sacc_data) + tpms = extract_all_harmonic_data(sacc_data) - assert two_point_cells[0].Cell is not None - new_indices = two_point_cells[0].Cell.indices + new_indices = tpms[0].indices new_indices[0] = 3 - two_point_cells[0] = replace( - two_point_cells[0], - Cell=replace(two_point_cells[0].Cell, indices=new_indices), - ) + tpms[0] = replace(tpms[0], indices=new_indices) with pytest.raises( ValueError, - match="The indices of the TwoPointCells .* are not unique.", + match=".* are not unique.", ): - check_two_point_consistence_harmonic(two_point_cells) + check_two_point_consistence_harmonic(tpms) -def test_check_two_point_consistence_real_non_unique_indices(sacc_galaxy_xis): +def test_check_two_point_consistence_real_non_unique_indices( + sacc_galaxy_xis: tuple[sacc.Sacc, dict, dict] +): sacc_data, _, _ = sacc_galaxy_xis - two_point_xis = extract_all_data_xi_thetas(sacc_data) - assert two_point_xis[0].xis is not None - new_indices = two_point_xis[0].xis.indices + tpms = extract_all_real_data(sacc_data) + new_indices = tpms[0].indices new_indices[0] = 3 - two_point_xis[0] = replace( - two_point_xis[0], - xis=replace(two_point_xis[0].xis, indices=new_indices), - ) + tpms[0] = replace(tpms[0], indices=new_indices) with pytest.raises( ValueError, - match="The indices of the TwoPointXiTheta .* are not unique.", + match=".* are not unique.", ): - check_two_point_consistence_real(two_point_xis) + check_two_point_consistence_real(tpms) -def test_check_two_point_consistence_harmonic_indices_overlap(sacc_galaxy_cells): +def test_check_two_point_consistence_harmonic_indices_overlap( + sacc_galaxy_cells: tuple[sacc.Sacc, dict, dict] +): sacc_data, _, _ = sacc_galaxy_cells - two_point_cells, _ = extract_all_data_cells(sacc_data) + tpms = extract_all_harmonic_data(sacc_data) - assert two_point_cells[1].Cell is not None - new_indices = two_point_cells[1].Cell.indices + new_indices = tpms[1].indices new_indices[1] = 3 - two_point_cells[1] = replace( - two_point_cells[1], - Cell=replace(two_point_cells[1].Cell, indices=new_indices), - ) + tpms[1] = replace(tpms[1], indices=new_indices) with pytest.raises( ValueError, - match="The indices of the TwoPointCells .* overlap.", + match=".* overlap.", ): - check_two_point_consistence_harmonic(two_point_cells) + check_two_point_consistence_harmonic(tpms) -def test_check_two_point_consistence_real_indices_overlap(sacc_galaxy_xis): +def test_check_two_point_consistence_real_indices_overlap( + sacc_galaxy_xis: tuple[sacc.Sacc, dict, dict] +): sacc_data, _, _ = sacc_galaxy_xis - two_point_xis = extract_all_data_xi_thetas(sacc_data) - assert two_point_xis[1].xis is not None - new_indices = two_point_xis[1].xis.indices + tpms = extract_all_real_data(sacc_data) + + new_indices = tpms[1].indices new_indices[1] = 3 - two_point_xis[1] = replace( - two_point_xis[1], - xis=replace(two_point_xis[1].xis, indices=new_indices), - ) + tpms[1] = replace(tpms[1], indices=new_indices) with pytest.raises( ValueError, - match="The indices of the TwoPointXiTheta .* overlap.", + match=".* overlap.", ): - check_two_point_consistence_real(two_point_xis) + check_two_point_consistence_real(tpms) -def 
test_extract_all_data_cells_ambiguous(sacc_galaxy_cells_ambiguous): +def test_extract_all_data_cells_ambiguous(sacc_galaxy_cells_ambiguous: sacc.Sacc): with pytest.raises( ValueError, match=( @@ -576,6 +584,6 @@ def test_extract_all_data_cells_ambiguous(sacc_galaxy_cells_ambiguous): "determine which measurement is from which tracer." ), ): - _ = extract_all_data_cells( + _ = extract_all_harmonic_data( sacc_galaxy_cells_ambiguous, include_maybe_types=True ) diff --git a/tutorial/inferred_zdist.qmd b/tutorial/inferred_zdist.qmd index 35d001a5..42eb7380 100644 --- a/tutorial/inferred_zdist.qmd +++ b/tutorial/inferred_zdist.qmd @@ -25,7 +25,7 @@ The `InferredGalaxyZDist` object serves as the cornerstone for representing the The `InferredGalaxyZDist` object can be created by providing the following parameters: ```{python} -from firecrown.metadata.two_point_types import Galaxies, InferredGalaxyZDist +from firecrown.metadata_types import Galaxies, InferredGalaxyZDist import numpy as np z = np.linspace(0.0, 1.0, 200) @@ -48,7 +48,7 @@ The `TwoPointCell` object is used to encapsulate the two-point correlation funct Using the `lens0` object created above, we can create a `TwoPointHarmonic` object as follows: ```{python} -from firecrown.metadata.two_point import TwoPointXY, TwoPointCells +from firecrown.metadata_types import TwoPointXY, TwoPointHarmonic from firecrown.generators.two_point import LogLinearElls lens0_lens0 = TwoPointXY( @@ -61,12 +61,12 @@ lens0_lens0 = TwoPointXY( # terms. Firecrown allows their inclusion, if you so desire. ells_generator = LogLinearElls(minimum=2, midpoint=20, maximum=200, n_log=20) -lens0_lens0_cell = TwoPointCells(XY=lens0_lens0, ells=ells_generator.generate()) +lens0_lens0_cell = TwoPointHarmonic(XY=lens0_lens0, ells=ells_generator.generate()) ``` -The `TwoPointCells` is created by providing the `TwoPointXY` object and the ells array. -The reason for separating the `TwoPointXY` and `TwoPointCells` objects is to accommodate the creation of other types of two-point correlation functions. -For example, from a `TwoPointXY`, one can also create a `TwoPointXiTheta` object for the two-point correlation function in real space, or a `TwoPointCWindow` object for the two-point correlation function in harmonic space with a window function. +The `TwoPointHarmonic` is created by providing the `TwoPointXY` object and the ells array. +The reason for separating the `TwoPointXY` and `TwoPointHarmonic` objects is to accommodate the creation of other types of two-point correlation functions. +For example, from a `TwoPointXY`, one can also create a `TwoPointReal` object for the two-point correlation function in real space, or a `TwoPointHarmonic` object with a window function for binned measurements in harmonic space.
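As a minimal sketch of these alternatives, reusing the `lens0_lens0` object defined above (the constructor signatures follow the tests in this patch; the `thetas` grid, the window matrix, and the variable names below are illustrative):

```{python}
import numpy as np
from firecrown.metadata_types import TwoPointHarmonic, TwoPointReal

# Real-space metadata: angular separations instead of multipoles.
thetas = np.linspace(0.01, 1.0, 20)
lens0_lens0_real = TwoPointReal(XY=lens0_lens0, thetas=thetas)

# Harmonic-space metadata with a window: each column of the window matrix
# maps the per-ell theory vector onto one binned observation.
ells = np.arange(2, 102)
window = np.ones((len(ells), 4)) / len(ells)  # four illustrative bandpowers
lens0_lens0_windowed = TwoPointHarmonic(XY=lens0_lens0, ells=ells, window=window)
```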
## Creating a TwoPoint Theory Object @@ -77,7 +77,7 @@ Using the `lens0_lens0_cell` object created above, we can create a `TwoPoint` ob from firecrown.likelihood.two_point import TwoPoint, NumberCountsFactory # Generate a list of TwoPoint objects, one for each bin provided -all_two_points = TwoPoint.from_metadata_harmonic( +all_two_points = TwoPoint.from_metadata( [lens0_lens0_cell], # Note we are passing a list with 1 bin; you could pass several nc_factory=NumberCountsFactory(global_systematics=[], per_bin_systematics=[]), ) diff --git a/tutorial/inferred_zdist_generators.qmd b/tutorial/inferred_zdist_generators.qmd index 92bc1415..7f3dd70d 100644 --- a/tutorial/inferred_zdist_generators.qmd +++ b/tutorial/inferred_zdist_generators.qmd @@ -23,7 +23,7 @@ An object of this type contains the redshifts, the corresponding probabilities, Firecrown also provides facilities to generate these distributions, given a chosen model. In `firecrown.generators`, we have the `ZDistLSSTSRD` class, which can be used to generate redshift distributions according to the LSST SRD. -[^1]: The metadata classes described in this tutorial, unless otherwise noted, are defined in the package `firecrown.metadata`. +[^1]: The metadata classes described in this tutorial, unless otherwise noted, are defined in the module `firecrown.metadata_types`. ## LSST SRD Distribution @@ -118,7 +118,7 @@ from firecrown.generators.inferred_galaxy_zdist import ( Y1_SOURCE_BINS, Y10_SOURCE_BINS, ) -from firecrown.metadata.two_point_types import Measurement, Galaxies +from firecrown.metadata_types import Measurement, Galaxies # These are the values at which we will evaluate the distribution. z = np.linspace(0.0, 0.6, 100) diff --git a/tutorial/inferred_zdist_serialization.qmd b/tutorial/inferred_zdist_serialization.qmd index 96d4c16b..062265f4 100644 --- a/tutorial/inferred_zdist_serialization.qmd +++ b/tutorial/inferred_zdist_serialization.qmd @@ -27,7 +27,7 @@ The code snippet below demonstrates how to initialize and serialize this datacla ```{python} from firecrown.generators.inferred_galaxy_zdist import ZDistLSSTSRDBin, LinearGrid1D -from firecrown.metadata.two_point_types import Galaxies +from firecrown.metadata_types import Galaxies from firecrown.utils import base_model_to_yaml z = LinearGrid1D(start=0.01, end=0.5, num=20) diff --git a/tutorial/two_point_factories.qmd b/tutorial/two_point_factories.qmd index a0bb2faf..bfe8eb88 100644 --- a/tutorial/two_point_factories.qmd +++ b/tutorial/two_point_factories.qmd @@ -27,14 +27,14 @@ One shortcoming of this methodology is that the user must know the `SACC` object Then, a matching `SACC` object must be passed in the `read` phase. To simplify this, we introduced functions to extract the metadata from the `SACC` file, such that it can be used together with the factories to create the objects. -In the code below we extract the metadata from the `examples/des_y1_3x2pt` `SACC` file. +In the code below we extract the metadata indices from the `examples/des_y1_3x2pt` `SACC` file.
```{python} -from firecrown.metadata.two_point import extract_all_data_types_xi_thetas +from firecrown.metadata_functions import extract_all_real_metadata_indices import sacc sacc_data = sacc.Sacc.load_fits("../examples/des_y1_3x2pt/des_y1_3x2pt_sacc_data.fits") -all_meta = extract_all_data_types_xi_thetas(sacc_data) +all_meta = extract_all_real_metadata_indices(sacc_data) ``` The metadata can be seen below: @@ -92,8 +92,10 @@ global_systematics: [] weak_lensing_factory = base_model_from_yaml(WeakLensingFactory, weak_lensing_yaml) number_counts_factory = base_model_from_yaml(NumberCountsFactory, number_counts_yaml) -two_point_list = TwoPoint.from_metadata_only_real( - metadata=all_meta, wl_factory=weak_lensing_factory, nc_factory=number_counts_factory +two_point_list = TwoPoint.from_metadata_index( + metadata_indices=all_meta, + wl_factory=weak_lensing_factory, + nc_factory=number_counts_factory, ) ``` @@ -120,13 +122,17 @@ This data may be extracted from a `SACC` object, generated by a custom generator In the code below we extract all components from the `examples/des_y1_3x2pt` `SACC` file. ```{python} -from firecrown.metadata.two_point import extract_all_data_xi_thetas +from firecrown.data_functions import ( + extract_all_real_data, + check_two_point_consistence_real, +) from pprint import pprint -two_point_xi_thetas = extract_all_data_xi_thetas(sacc_data) +two_point_reals = extract_all_real_data(sacc_data) +check_two_point_consistence_real(two_point_reals) ``` -The list `two_point_xi_thetas` contains all information about real-space two point functions contained in the `SACC` object. +The list `two_point_reals` contains all information about real-space two-point functions contained in the `SACC` object. Therefore, we need to use a different constructor to obtain the two-point objects already in the `ready` state. ```{python} from firecrown.likelihood.two_point import TwoPoint @@ -134,11 +140,11 @@ from firecrown.likelihood.two_point import TwoPoint weak_lensing_factory = base_model_from_yaml(WeakLensingFactory, weak_lensing_yaml) number_counts_factory = base_model_from_yaml(NumberCountsFactory, number_counts_yaml) -two_points_ready = TwoPoint.from_metadata_real( - two_point_xi_thetas, + +two_points_ready = TwoPoint.from_measurement( + two_point_reals, wl_factory=weak_lensing_factory, nc_factory=number_counts_factory, - check_consistence=True, ) ``` diff --git a/tutorial/two_point_generators.qmd b/tutorial/two_point_generators.qmd index bb5da5ac..dd38577d 100644 --- a/tutorial/two_point_generators.qmd +++ b/tutorial/two_point_generators.qmd @@ -37,17 +37,17 @@ all_y1_bins = count_bins + shear_bins ## Generating the Two-Point Metadata -Within `firecrown.metadata`, a suite of functions is available to calculate all possible two-point statistic metadata corresponding to a designated set of bins. +Within `firecrown.metadata_functions`, a suite of functions is available to calculate all possible two-point statistic metadata corresponding to a designated set of bins.
For instance, demonstrated below is the computation of all feasible metadata tailored to the LSST year 1 redshift distribution: ```{python} import numpy as np -from firecrown.metadata.two_point import make_all_photoz_bin_combinations, TwoPointCells +from firecrown.metadata_functions import make_all_photoz_bin_combinations, TwoPointHarmonic all_two_point_xy = make_all_photoz_bin_combinations(all_y1_bins) ells = np.unique(np.geomspace(2, 2000, 128).astype(int)) -all_two_point_cells = [TwoPointCells(XY=xy, ells=ells) for xy in all_two_point_xy] +all_two_point_cells = [TwoPointHarmonic(XY=xy, ells=ells) for xy in all_two_point_xy] ``` @@ -138,8 +138,8 @@ These factories are responsible for producing the necessary systematics applied Below is the code defining the factories for weak lensing and number counts systematics: ```{python} -all_two_point_functions = tp.TwoPoint.from_metadata_harmonic( - metadata=all_two_point_cells, +all_two_point_functions = tp.TwoPoint.from_metadata( + metadata_seq=all_two_point_cells, wl_factory=wlf, nc_factory=ncf, ) From 417c94cecaff2ae27e4f01e0993a2221673f7e7b Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Wed, 11 Sep 2024 18:49:44 -0300 Subject: [PATCH 07/10] Adding support for NumberCounts factories including RSD. (#449) * Adding support for NumberCounts factories including RSD. * Checking if RSD parameter is correctly set. --- firecrown/likelihood/number_counts.py | 9 +- .../gauss_family/statistic/test_two_point.py | 177 +++++++++++------- 2 files changed, 115 insertions(+), 71 deletions(-) diff --git a/firecrown/likelihood/number_counts.py b/firecrown/likelihood/number_counts.py index fe4b528f..ecc7406f 100644 --- a/firecrown/likelihood/number_counts.py +++ b/firecrown/likelihood/number_counts.py @@ -646,6 +646,7 @@ class NumberCountsFactory(BaseModel): per_bin_systematics: Sequence[NumberCountsSystematicFactory] global_systematics: Sequence[NumberCountsSystematicFactory] + include_rsd: bool = False def model_post_init(self, _) -> None: """Initialize the NumberCountsFactory. 
@@ -674,7 +675,9 @@ def create(self, inferred_zdist: InferredGalaxyZDist) -> NumberCounts: ] systematics.extend(self._global_systematics_instances) - nc = NumberCounts.create_ready(inferred_zdist, systematics=systematics) + nc = NumberCounts.create_ready( + inferred_zdist, systematics=systematics, has_rsd=self.include_rsd + ) self._cache[inferred_zdist_id] = nc return nc @@ -697,7 +700,9 @@ def create_from_metadata_only( ] systematics.extend(self._global_systematics_instances) - nc = NumberCounts(sacc_tracer=sacc_tracer, systematics=systematics) + nc = NumberCounts( + sacc_tracer=sacc_tracer, systematics=systematics, has_rsd=self.include_rsd + ) self._cache[sacc_tracer_id] = nc return nc diff --git a/tests/likelihood/gauss_family/statistic/test_two_point.py b/tests/likelihood/gauss_family/statistic/test_two_point.py index f71fda8b..599769c3 100644 --- a/tests/likelihood/gauss_family/statistic/test_two_point.py +++ b/tests/likelihood/gauss_family/statistic/test_two_point.py @@ -45,6 +45,12 @@ from firecrown.data_types import TwoPointMeasurement +@pytest.fixture(name="include_rsd", params=[True, False], ids=["rsd", "no_rsd"]) +def fixture_include_rsd(request) -> bool: + """Return whether to include RSD in the test.""" + return request.param + + @pytest.fixture(name="source_0") def fixture_source_0() -> NumberCounts: """Return an almost-default NumberCounts source.""" @@ -98,20 +104,60 @@ def fixture_harmonic_data_no_window(harmonic_two_point_xy) -> TwoPointMeasuremen return tpm -def test_ell_for_xi_no_rounding(): +@pytest.fixture(name="wl_factory") +def fixture_wl_factory() -> WeakLensingFactory: + """Return a WeakLensingFactory object.""" + return WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) + + +@pytest.fixture(name="nc_factory") +def fixture_nc_factory(include_rsd: bool) -> NumberCountsFactory: + """Return a NumberCountsFactory object.""" + return NumberCountsFactory( + per_bin_systematics=[], global_systematics=[], include_rsd=include_rsd + ) + + +@pytest.fixture(name="two_point_with_window") +def fixture_two_point_with_window( + harmonic_data_with_window: TwoPointMeasurement, + wl_factory: WeakLensingFactory, + nc_factory: NumberCountsFactory, +) -> TwoPoint: + """Return a TwoPoint object with a window.""" + two_points = TwoPoint.from_measurement( + [harmonic_data_with_window], wl_factory=wl_factory, nc_factory=nc_factory + ) + return two_points.pop() + + +@pytest.fixture(name="two_point_without_window") +def fixture_two_point_without_window( + harmonic_data_no_window: TwoPointMeasurement, + wl_factory: WeakLensingFactory, + nc_factory: NumberCountsFactory, +) -> TwoPoint: + """Return a TwoPoint object without a window.""" + two_points = TwoPoint.from_measurement( + [harmonic_data_no_window], wl_factory=wl_factory, nc_factory=nc_factory + ) + return two_points.pop() + + +def test_ell_for_xi_no_rounding() -> None: res = _ell_for_xi(minimum=0, midpoint=5, maximum=80, n_log=5) expected = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 40.0, 80.0]) assert res.shape == expected.shape assert np.allclose(expected, res) -def test_ell_for_xi_doing_rounding(): +def test_ell_for_xi_doing_rounding() -> None: res = _ell_for_xi(minimum=1, midpoint=3, maximum=100, n_log=5) expected = np.array([1.0, 2.0, 3.0, 7.0, 17.0, 42.0, 100.0]) assert np.allclose(expected, res) -def test_compute_theory_vector(source_0: NumberCounts): +def test_compute_theory_vector(source_0: NumberCounts) -> None: # To create the TwoPoint object we need at least one source. 
statistic = TwoPoint("galaxy_density_xi", source_0, source_0) assert isinstance(statistic, TwoPoint) @@ -122,7 +168,7 @@ def test_compute_theory_vector(source_0: NumberCounts): # assert isinstance(prediction, TheoryVector) -def test_tracer_names(): +def test_tracer_names() -> None: assert TracerNames("", "") == TRACER_NAMES_TOTAL tn1 = TracerNames("cow", "pig") @@ -137,7 +183,7 @@ def test_tracer_names(): _ = tn1[2] -def test_two_point_src0_src0_window(sacc_galaxy_cells_src0_src0_window): +def test_two_point_src0_src0_window(sacc_galaxy_cells_src0_src0_window) -> None: """This test also makes sure that TwoPoint theory calculations are repeatable.""" sacc_data, _, _ = sacc_galaxy_cells_src0_src0_window @@ -166,7 +212,7 @@ def test_two_point_src0_src0_window(sacc_galaxy_cells_src0_src0_window): assert np.array_equal(result1, result2) -def test_two_point_src0_src0_no_window(sacc_galaxy_cells_src0_src0_no_window): +def test_two_point_src0_src0_no_window(sacc_galaxy_cells_src0_src0_no_window) -> None: """This test also makes sure that TwoPoint theory calculations are repeatable.""" sacc_data, _, _ = sacc_galaxy_cells_src0_src0_no_window @@ -195,7 +241,7 @@ def test_two_point_src0_src0_no_window(sacc_galaxy_cells_src0_src0_no_window): assert np.array_equal(result1, result2) -def test_two_point_src0_src0_no_data_lin(sacc_galaxy_cells_src0_src0_no_data): +def test_two_point_src0_src0_no_data_lin(sacc_galaxy_cells_src0_src0_no_data) -> None: sacc_data, _, _ = sacc_galaxy_cells_src0_src0_no_data src0 = WeakLensing(sacc_tracer="src0") @@ -224,7 +270,7 @@ def test_two_point_src0_src0_no_data_lin(sacc_galaxy_cells_src0_src0_no_data): assert all(statistic.ells <= 100) -def test_two_point_src0_src0_no_data_log(sacc_galaxy_cells_src0_src0_no_data): +def test_two_point_src0_src0_no_data_log(sacc_galaxy_cells_src0_src0_no_data) -> None: sacc_data, _, _ = sacc_galaxy_cells_src0_src0_no_data src0 = WeakLensing(sacc_tracer="src0") @@ -253,7 +299,7 @@ def test_two_point_src0_src0_no_data_log(sacc_galaxy_cells_src0_src0_no_data): assert all(statistic.ells <= 100) -def test_two_point_lens0_lens0_no_data(sacc_galaxy_xis_lens0_lens0_no_data): +def test_two_point_lens0_lens0_no_data(sacc_galaxy_xis_lens0_lens0_no_data) -> None: sacc_data, _, _ = sacc_galaxy_xis_lens0_lens0_no_data src0 = NumberCounts(sacc_tracer="lens0") @@ -282,7 +328,7 @@ def test_two_point_lens0_lens0_no_data(sacc_galaxy_xis_lens0_lens0_no_data): assert all(statistic.thetas <= 1.0) -def test_two_point_src0_src0_cuts(sacc_galaxy_cells_src0_src0): +def test_two_point_src0_src0_cuts(sacc_galaxy_cells_src0_src0) -> None: sacc_data, _, _ = sacc_galaxy_cells_src0_src0 src0 = WeakLensing(sacc_tracer="src0") @@ -309,7 +355,7 @@ def test_two_point_src0_src0_cuts(sacc_galaxy_cells_src0_src0): statistic.compute_theory_vector(tools) -def test_two_point_lens0_lens0_cuts(sacc_galaxy_xis_lens0_lens0): +def test_two_point_lens0_lens0_cuts(sacc_galaxy_xis_lens0_lens0) -> None: sacc_data, _, _ = sacc_galaxy_xis_lens0_lens0 src0 = NumberCounts(sacc_tracer="lens0") @@ -334,7 +380,7 @@ def test_two_point_lens0_lens0_cuts(sacc_galaxy_xis_lens0_lens0): statistic.compute_theory_vector(tools) -def test_two_point_lens0_lens0_config(sacc_galaxy_xis_lens0_lens0): +def test_two_point_lens0_lens0_config(sacc_galaxy_xis_lens0_lens0) -> None: sacc_data, _, _ = sacc_galaxy_xis_lens0_lens0 src0 = NumberCounts(sacc_tracer="lens0") @@ -365,7 +411,7 @@ def test_two_point_lens0_lens0_config(sacc_galaxy_xis_lens0_lens0): statistic.compute_theory_vector(tools) -def 
test_two_point_src0_src0_no_data_error(sacc_galaxy_cells_src0_src0_no_data): +def test_two_point_src0_src0_no_data_error(sacc_galaxy_cells_src0_src0_no_data) -> None: sacc_data, _, _ = sacc_galaxy_cells_src0_src0_no_data src0 = WeakLensing(sacc_tracer="src0") @@ -382,7 +428,9 @@ def test_two_point_src0_src0_no_data_error(sacc_galaxy_cells_src0_src0_no_data): statistic.read(sacc_data) -def test_two_point_lens0_lens0_no_data_error(sacc_galaxy_xis_lens0_lens0_no_data): +def test_two_point_lens0_lens0_no_data_error( + sacc_galaxy_xis_lens0_lens0_no_data, +) -> None: sacc_data, _, _ = sacc_galaxy_xis_lens0_lens0_no_data src0 = NumberCounts(sacc_tracer="lens0") @@ -399,7 +447,9 @@ def test_two_point_lens0_lens0_no_data_error(sacc_galaxy_xis_lens0_lens0_no_data statistic.read(sacc_data) -def test_two_point_src0_src0_data_and_conf_warn(sacc_galaxy_cells_src0_src0_window): +def test_two_point_src0_src0_data_and_conf_warn( + sacc_galaxy_cells_src0_src0_window, +) -> None: sacc_data, _, _ = sacc_galaxy_cells_src0_src0_window src0 = WeakLensing(sacc_tracer="src0") @@ -422,7 +472,7 @@ def test_two_point_src0_src0_data_and_conf_warn(sacc_galaxy_cells_src0_src0_wind statistic.read(sacc_data) -def test_two_point_lens0_lens0_data_and_conf_warn(sacc_galaxy_xis_lens0_lens0): +def test_two_point_lens0_lens0_data_and_conf_warn(sacc_galaxy_xis_lens0_lens0) -> None: sacc_data, _, _ = sacc_galaxy_xis_lens0_lens0 src0 = NumberCounts(sacc_tracer="lens0") @@ -445,22 +495,26 @@ def test_two_point_lens0_lens0_data_and_conf_warn(sacc_galaxy_xis_lens0_lens0): statistic.read(sacc_data) -def test_use_source_factory(harmonic_bin_1: InferredGalaxyZDist): - wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) - +def test_use_source_factory( + harmonic_bin_1: InferredGalaxyZDist, + wl_factory: WeakLensingFactory, + nc_factory: NumberCountsFactory, +) -> None: measurement = list(harmonic_bin_1.measurements)[0] source = use_source_factory(harmonic_bin_1, measurement, wl_factory, nc_factory) if measurement in GALAXY_LENS_TYPES: assert isinstance(source, NumberCounts) + assert source.has_rsd == nc_factory.include_rsd elif measurement in GALAXY_SOURCE_TYPES: assert isinstance(source, WeakLensing) else: assert False, f"Unknown measurement type: {measurement}" -def test_use_source_factory_invalid_measurement(harmonic_bin_1: InferredGalaxyZDist): +def test_use_source_factory_invalid_measurement( + harmonic_bin_1: InferredGalaxyZDist, +) -> None: with pytest.raises( ValueError, match="Measurement .* not found in inferred galaxy redshift distribution .*", @@ -468,41 +522,45 @@ def test_use_source_factory_invalid_measurement(harmonic_bin_1: InferredGalaxyZD use_source_factory(harmonic_bin_1, Galaxies.SHEAR_MINUS, None, None) -def test_use_source_factory_metadata_only_counts(): - wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) +def test_use_source_factory_metadata_only_counts( + wl_factory: WeakLensingFactory, + nc_factory: NumberCountsFactory, +) -> None: source = use_source_factory_metadata_index( "bin1", Galaxies.COUNTS, wl_factory=wl_factory, nc_factory=nc_factory ) assert isinstance(source, NumberCounts) + assert source.has_rsd == nc_factory.include_rsd -def test_use_source_factory_metadata_only_shear(): - wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - nc_factory = 
NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) +def test_use_source_factory_metadata_only_shear( + wl_factory: WeakLensingFactory, + nc_factory: NumberCountsFactory, +) -> None: source = use_source_factory_metadata_index( "bin1", Galaxies.SHEAR_E, wl_factory=wl_factory, nc_factory=nc_factory ) assert isinstance(source, WeakLensing) -def test_use_source_factory_metadata_only_invalid_measurement(): - wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) +def test_use_source_factory_metadata_only_invalid_measurement( + wl_factory: WeakLensingFactory, + nc_factory: NumberCountsFactory, +) -> None: with pytest.raises(ValueError, match="Unknown measurement type encountered .*"): use_source_factory_metadata_index( "bin1", 120, wl_factory=wl_factory, nc_factory=nc_factory # type: ignore ) -def test_two_point_wrong_type(): +def test_two_point_wrong_type() -> None: with pytest.raises(ValueError, match="The SACC data type cow is not supported!"): TwoPoint( "cow", WeakLensing(sacc_tracer="calma"), WeakLensing(sacc_tracer="fernando") ) -def test_from_metadata_harmonic_wrong_metadata(): +def test_from_metadata_harmonic_wrong_metadata() -> None: with pytest.raises( ValueError, match=re.escape("Metadata of type is not supported") ): @@ -511,7 +569,7 @@ def test_from_metadata_harmonic_wrong_metadata(): ) -def test_use_source_factory_metadata_only_wrong_measurement(): +def test_use_source_factory_metadata_only_wrong_measurement() -> None: unknown_type = MagicMock() unknown_type.configure_mock(__eq__=MagicMock(return_value=False)) @@ -521,9 +579,10 @@ def test_use_source_factory_metadata_only_wrong_measurement(): ) -def test_from_metadata_only_harmonic(): - wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) +def test_from_metadata_only_harmonic( + wl_factory: WeakLensingFactory, + nc_factory: NumberCountsFactory, +) -> None: metadata: TwoPointHarmonicIndex = { "data_type": "galaxy_density_xi", "tracer_names": TracerNames("lens0", "lens0"), @@ -537,9 +596,10 @@ def test_from_metadata_only_harmonic(): assert not two_point.ready -def test_from_metadata_only_real(): - wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) +def test_from_metadata_only_real( + wl_factory: WeakLensingFactory, + nc_factory: NumberCountsFactory, +) -> None: metadata: TwoPointRealIndex = { "data_type": "galaxy_shear_xi_plus", "tracer_names": TracerNames("src0", "src0"), @@ -554,50 +614,29 @@ def test_from_metadata_only_real(): def test_from_measurement_compute_theory_vector_window( - harmonic_data_with_window: TwoPointMeasurement, -): - wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) - two_points = TwoPoint.from_measurement( - [harmonic_data_with_window], wl_factory=wl_factory, nc_factory=nc_factory - ) - two_point: TwoPoint = two_points.pop() - - assert isinstance(two_point, TwoPoint) - assert two_point.ready + two_point_with_window: TwoPoint, +) -> None: + assert isinstance(two_point_with_window, TwoPoint) + assert two_point_with_window.ready - req_params = two_point.required_parameters() + req_params = two_point_with_window.required_parameters() default_values = 
req_params.get_default_values() params = ParamsMap(default_values) tools = ModelingTools() tools.update(params) tools.prepare(pyccl.CosmologyVanillaLCDM()) - two_point.update(params) + two_point_with_window.update(params) - prediction = two_point.compute_theory_vector(tools) + prediction = two_point_with_window.compute_theory_vector(tools) assert isinstance(prediction, TheoryVector) assert prediction.shape == (4,) def test_from_measurement_compute_theory_vector_window_check( - harmonic_data_with_window: TwoPointMeasurement, - harmonic_data_no_window: TwoPointMeasurement, -): - wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) - - two_points_with_window = TwoPoint.from_measurement( - [harmonic_data_with_window], wl_factory=wl_factory, nc_factory=nc_factory - ) - two_point_with_window: TwoPoint = two_points_with_window.pop() - - two_points_without_window = TwoPoint.from_measurement( - [harmonic_data_no_window], wl_factory=wl_factory, nc_factory=nc_factory - ) - two_point_without_window: TwoPoint = two_points_without_window.pop() - + two_point_with_window: TwoPoint, two_point_without_window: TwoPoint +) -> None: assert isinstance(two_point_with_window, TwoPoint) assert two_point_with_window.ready From eab916f48fa7d5ab020c950abd483b29e8aebf10 Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Fri, 20 Sep 2024 08:08:35 -0500 Subject: [PATCH 08/10] Fix typo in variable name (#453) --- firecrown/generators/inferred_galaxy_zdist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firecrown/generators/inferred_galaxy_zdist.py b/firecrown/generators/inferred_galaxy_zdist.py index 220972a6..0e0fe06f 100644 --- a/firecrown/generators/inferred_galaxy_zdist.py +++ b/firecrown/generators/inferred_galaxy_zdist.py @@ -395,7 +395,7 @@ def generate(self) -> list[InferredGalaxyZDist]: ], ) -LSSST_Y10_SOURCE_BIN_COLLECTION = ZDistLSSTSRDBinCollection( +LSST_Y10_SOURCE_BIN_COLLECTION = ZDistLSSTSRDBinCollection( alpha=Y10_ALPHA, beta=Y10_BETA, z0=Y10_Z0, From 1e6972f031390f1f2171a9b8d16ec5ab6412e7ff Mon Sep 17 00:00:00 2001 From: Marc Paterno Date: Mon, 23 Sep 2024 15:14:54 -0500 Subject: [PATCH 09/10] Temporary tweaks to fix CI failures from pylint 3.3 (#457) --- firecrown/models/pylintrc | 2 ++ pylintrc | 1 + 2 files changed, 3 insertions(+) diff --git a/firecrown/models/pylintrc b/firecrown/models/pylintrc index 645fb1a9..a07f982f 100644 --- a/firecrown/models/pylintrc +++ b/firecrown/models/pylintrc @@ -51,6 +51,8 @@ function-naming-style=any [DESIGN] +# Maximum number of positional arguments for function / method. +max-positional-arguments=6 # Maximum number of arguments for function / method. max-args=11 diff --git a/pylintrc b/pylintrc index 769b8f48..8792b18c 100644 --- a/pylintrc +++ b/pylintrc @@ -67,5 +67,6 @@ ignore-none=false expected-line-ending-format=LF [DESIGN] +max-positional-arguments=8 max-args=20 max-attributes=20 From 097769cd0b6f6cd182e74d320cf64513d933881c Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Tue, 24 Sep 2024 18:11:27 -0300 Subject: [PATCH 10/10] New firecrown.likelihood.factories module for general factories. (#450) Add new likelihood factories module and CCLFactory implementation * Introduced `firecrown.likelihood.factories` module for general likelihood factories, along with examples. * Created `CCLFactory` object and adapted existing tests and connectors to support it. 
* Added tests for `NumCosmo` connector, `ccl_factory`, and other relevant configurations. * Updated old and untested examples, including tutorials. * Improved code clarity and exception handling, including fixing exception chaining and message testing. * Refactored experiment settings and cleaned up generated files. * Added configuration for new `pylint` options and removed repeated options. * Conducted extensive testing on parameter validation, including invalid likelihood configurations and different combinations of PK. --- examples/des_y1_3x2pt/des_y1_3x2pt.py | 6 +- examples/des_y1_3x2pt/des_y1_3x2pt_PT.py | 18 +- .../des_y1_3x2pt_default_factory.ini | 60 +++ .../des_y1_3x2pt/des_y1_3x2pt_experiment.yaml | 29 ++ .../des_y1_3x2pt/des_y1_cosmic_shear_TATT.py | 64 +-- .../des_y1_cosmic_shear_pk_modifier.py | 46 +-- firecrown/ccl_factory.py | 233 +++++++++++ firecrown/connector/cobaya/ccl.py | 45 ++- firecrown/connector/cobaya/likelihood.py | 26 +- firecrown/connector/cosmosis/likelihood.py | 29 +- firecrown/connector/mapping.py | 63 ++- firecrown/connector/numcosmo/numcosmo.py | 148 ++++--- firecrown/likelihood/factories.py | 224 +++++++++++ firecrown/likelihood/likelihood.py | 42 +- firecrown/modeling_tools.py | 12 +- firecrown/parameters.py | 42 +- firecrown/updatable.py | 32 +- tests/conftest.py | 8 +- .../connector/cobaya/test_model_likelihood.py | 20 +- .../cosmosis/test_cosmosis_module.py | 26 -- .../numcosmo/test_numcosmo_connector.py | 10 +- .../numcosmo/test_numcosmo_mapping.py | 128 ++---- tests/connector/test_mapping.py | 2 - tests/integration/test_des_y1_3x2pt.py | 1 + .../gauss_family/lkscript_const_gaussian.py | 9 +- .../gauss_family/lkscript_two_point.py | 9 +- .../test_binned_cluster_number_counts.py | 20 +- .../gauss_family/statistic/test_supernova.py | 16 +- .../gauss_family/statistic/test_two_point.py | 71 ++-- .../likelihood/lkdir/lk_derived_parameter.py | 7 +- tests/likelihood/lkdir/lk_needing_param.py | 8 +- .../likelihood/lkdir/lk_sampler_parameter.py | 7 +- tests/likelihood/lkdir/lkscript.py | 5 +- tests/likelihood/test_factories.py | 315 +++++++++++++++ tests/likelihood/test_likelihood.py | 17 +- tests/test_bug398.py | 58 +-- tests/test_bug_413.py | 30 +- tests/test_ccl_factory.py | 367 ++++++++++++++++++ tests/test_modeling_tools.py | 50 +-- tests/test_parameters.py | 16 +- tests/test_pt_systematics.py | 33 +- tutorial/inferred_zdist.qmd | 12 +- tutorial/two_point_factories.qmd | 11 +- tutorial/two_point_generators.qmd | 24 +- 44 files changed, 1856 insertions(+), 543 deletions(-) create mode 100644 examples/des_y1_3x2pt/des_y1_3x2pt_default_factory.ini create mode 100644 examples/des_y1_3x2pt/des_y1_3x2pt_experiment.yaml create mode 100644 firecrown/ccl_factory.py create mode 100644 firecrown/likelihood/factories.py create mode 100644 tests/likelihood/test_factories.py create mode 100644 tests/test_ccl_factory.py diff --git a/examples/des_y1_3x2pt/des_y1_3x2pt.py b/examples/des_y1_3x2pt/des_y1_3x2pt.py index d51e79b2..fb10ffa4 100755 --- a/examples/des_y1_3x2pt/des_y1_3x2pt.py +++ b/examples/des_y1_3x2pt/des_y1_3x2pt.py @@ -7,6 +7,8 @@ import firecrown.likelihood.number_counts as nc from firecrown.likelihood.two_point import TwoPoint from firecrown.likelihood.gaussian import ConstGaussian +from firecrown.modeling_tools import ModelingTools +from firecrown.ccl_factory import CCLFactory # The likelihood used for DES Y1 3x2pt analysis is a Gaussian likelihood, which @@ -111,6 +113,8 @@ def build_likelihood(_): # file and the sources their respective dndz. 
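+    # The ModelingTools below carry a CCLFactory; require_nonlinear_pk=True
+    # requests a halofit non-linear power spectrum whenever the sampling
+    # framework does not supply one.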
likelihood.read(sacc_data) + modeling_tools = ModelingTools(ccl_factory=CCLFactory(require_nonlinear_pk=True)) + # This script will be loaded by the appropriated connector. The framework # will call the factory function that should return a Likelihood instance. - return likelihood + return likelihood, modeling_tools diff --git a/examples/des_y1_3x2pt/des_y1_3x2pt_PT.py b/examples/des_y1_3x2pt/des_y1_3x2pt_PT.py index 8e507639..9431b988 100755 --- a/examples/des_y1_3x2pt/des_y1_3x2pt_PT.py +++ b/examples/des_y1_3x2pt/des_y1_3x2pt_PT.py @@ -20,6 +20,8 @@ from firecrown.parameters import ParamsMap from firecrown.modeling_tools import ModelingTools from firecrown.likelihood.likelihood import Likelihood +from firecrown.ccl_factory import CCLFactory +from firecrown.updatable import get_default_params_map saccfile = os.path.expanduser( @@ -137,7 +139,9 @@ def build_likelihood(_) -> tuple[Likelihood, ModelingTools]: nk_per_decade=20, ) - modeling_tools = ModelingTools(pt_calculator=pt_calculator) + modeling_tools = ModelingTools( + pt_calculator=pt_calculator, ccl_factory=CCLFactory(require_nonlinear_pk=True) + ) # Note that the ordering of the statistics is relevant, because the data # vector, theory vector, and covariance matrix will be organized to # follow the order used here. @@ -173,9 +177,11 @@ def run_likelihood() -> None: z, nz = src0_tracer.z, src0_tracer.nz lens_z, lens_nz = lens0_tracer.z, lens0_tracer.nz - # Define a ccl.Cosmology object using default parameters - ccl_cosmo = ccl.CosmologyVanillaLCDM() - ccl_cosmo.compute_nonlin_power() + # Setting cosmological parameters + cosmo_params = get_default_params_map(tools) + tools.update(cosmo_params) + tools.prepare() + ccl_cosmo = tools.get_ccl_cosmology() cs = CclSetup() c_1, c_d, c_2 = pyccl.nl_pt.translate_IA_norm( @@ -223,10 +229,6 @@ def run_likelihood() -> None: # Apply the systematics parameters likelihood.update(systematics_params) - tools.update(systematics_params) - - # Prepare the cosmology object - tools.prepare(ccl_cosmo) # Compute the log-likelihood, using the ccl.Cosmology object as the input log_like = likelihood.compute_loglike(tools) diff --git a/examples/des_y1_3x2pt/des_y1_3x2pt_default_factory.ini b/examples/des_y1_3x2pt/des_y1_3x2pt_default_factory.ini new file mode 100644 index 00000000..8f6a211c --- /dev/null +++ b/examples/des_y1_3x2pt/des_y1_3x2pt_default_factory.ini @@ -0,0 +1,60 @@ +[runtime] +sampler = test +root = ${PWD} + +[DEFAULT] +fatal_errors = T + +[output] +filename = output/des_y1_3x2pt_samples.txt +format = text +verbosity = 0 + +[pipeline] +modules = consistency camb firecrown_likelihood +values = ${FIRECROWN_DIR}/examples/des_y1_3x2pt/des_y1_3x2pt_values.ini +likelihoods = firecrown +quiet = T +debug = T +timing = T +extra_output = TwoPoint/NumberCountsScale_lens0 TwoPoint/NumberCountsScale_lens1 TwoPoint/NumberCountsScale_lens2 TwoPoint/NumberCountsScale_lens3 TwoPoint/NumberCountsScale_lens4 + +[consistency] +file = ${CSL_DIR}/utility/consistency/consistency_interface.py + +[camb] +file = ${CSL_DIR}/boltzmann/camb/camb_interface.py + +mode = all +lmax = 2500 +feedback = 0 +zmin = 0.0 +zmax = 4.0 +nz = 100 +kmin = 1e-4 +kmax = 50.0 +nk = 1000 + +[firecrown_likelihood] +;; Fix this to use an environment variable to find the files. 
+;; Set FIRECROWN_DIR to the base of the firecrown installation (or build, if you haven't +;; installed it) +file = ${FIRECROWN_DIR}/firecrown/connector/cosmosis/likelihood.py +likelihood_source = firecrown.likelihood.factories.build_two_point_likelihood +likelihood_config = ${FIRECROWN_DIR}/examples/des_y1_3x2pt/des_y1_3x2pt_experiment.yaml +;; Connector settings +require_nonlinear_pk = True +sampling_parameters_sections = firecrown_two_point + +[test] +fatal_errors = T +save_dir = des_y1_3x2pt_output + +[metropolis] +samples = 1000 +nsteps = 1 + +[emcee] +walkers = 64 +samples = 400 +nsteps = 10 diff --git a/examples/des_y1_3x2pt/des_y1_3x2pt_experiment.yaml b/examples/des_y1_3x2pt/des_y1_3x2pt_experiment.yaml new file mode 100644 index 00000000..298317a3 --- /dev/null +++ b/examples/des_y1_3x2pt/des_y1_3x2pt_experiment.yaml @@ -0,0 +1,29 @@ +--- + +# This file contains all information about the DES Y1 3x2pt experiment. + +# The 'data_source:' field points to the SACC file containing the DES Y1 3x2pt data vector +# and covariance matrix, which is used for analysis. + +# The 'two_point_factory:' field points to the factory that will create the TwoPoint +# objects. These objects represent the chosen theoretical models for each data point +# found in the SACC file. + +data_source: + sacc_data_file: des_y1_3x2pt_sacc_data.fits + +# The two point statistics are defined by the TwoPoint objects. The TwoPoint statistics +# are created using the factories defined in this file. +two_point_factory: + correlation_space: real + number_counts_factory: + global_systematics: [] + per_bin_systematics: + - type: PhotoZShiftFactory + weak_lensing_factory: + global_systematics: + - alphag: 1 + type: LinearAlignmentSystematicFactory + per_bin_systematics: + - type: MultiplicativeShearBiasFactory + - type: PhotoZShiftFactory diff --git a/examples/des_y1_3x2pt/des_y1_cosmic_shear_TATT.py b/examples/des_y1_3x2pt/des_y1_cosmic_shear_TATT.py index 640726aa..647268bd 100644 --- a/examples/des_y1_3x2pt/des_y1_cosmic_shear_TATT.py +++ b/examples/des_y1_3x2pt/des_y1_cosmic_shear_TATT.py @@ -11,7 +11,9 @@ from firecrown.parameters import ParamsMap from firecrown.modeling_tools import ModelingTools from firecrown.likelihood.likelihood import Likelihood - +from firecrown.ccl_factory import CCLFactory +from firecrown.updatable import get_default_params_map +from firecrown.metadata_types import TracerNames, TRACER_NAMES_TOTAL SACCFILE = os.path.expanduser( os.path.expandvars( @@ -38,7 +40,9 @@ def build_likelihood(_) -> tuple[Likelihood, ModelingTools]: nk_per_decade=20, ) - modeling_tools = ModelingTools(pt_calculator=pt_calculator) + modeling_tools = ModelingTools( + pt_calculator=pt_calculator, ccl_factory=CCLFactory(require_nonlinear_pk=True) + ) likelihood = ConstGaussian(statistics=list(stats.values())) # Read the two-point data from the sacc file @@ -102,14 +106,33 @@ def run_likelihood() -> None: src0_tracer = sacc_data.get_tracer("src0") z, nz = src0_tracer.z, src0_tracer.nz - # Define a ccl.Cosmology object using default parameters - ccl_cosmo = ccl.CosmologyVanillaLCDM() - ccl_cosmo.compute_nonlin_power() - # Bare CCL setup a_1 = 1.0 a_2 = 0.5 a_d = 0.5 + # Set the parameters for our systematics + systematics_params = ParamsMap( + { + "ia_a_1": a_1, + "ia_a_2": a_2, + "ia_a_d": a_d, + "src0_delta_z": 0.000, + "src1_delta_z": 0.003, + "src2_delta_z": -0.001, + "src3_delta_z": 0.002, + } + ) + # Prepare the cosmology object + params = ParamsMap(get_default_params_map(tools) | systematics_params) + + # Apply the 
systematics parameters + likelihood.update(params) + + # Prepare the cosmology object + tools.update(params) + tools.prepare() + ccl_cosmo = tools.get_ccl_cosmology() + c_1, c_d, c_2 = pyccl.nl_pt.translate_IA_norm( ccl_cosmo, z=z, a1=a_1, a1delta=a_d, a2=a_2, Om_m2_for_c2=False ) @@ -131,25 +154,6 @@ def run_likelihood() -> None: pk_im = ptc.get_biased_pk2d(tracer1=ptt_i, tracer2=ptt_m) pk_ii = ptc.get_biased_pk2d(tracer1=ptt_i, tracer2=ptt_i) - # Set the parameters for our systematics - systematics_params = ParamsMap( - { - "ia_a_1": a_1, - "ia_a_2": a_2, - "ia_a_d": a_d, - "src0_delta_z": 0.000, - "src1_delta_z": 0.003, - "src2_delta_z": -0.001, - "src3_delta_z": 0.002, - } - ) - - # Apply the systematics parameters - likelihood.update(systematics_params) - - # Prepare the cosmology object - tools.prepare(ccl_cosmo) - # Compute the log-likelihood, using the ccl.Cosmology object as the input log_like = likelihood.compute_loglike(tools) @@ -180,11 +184,11 @@ def make_plot(ccl_cosmo, nz, pk_ii, pk_im, two_point_0, z): import numpy as np # pylint: disable-msg=import-outside-toplevel import matplotlib.pyplot as plt # pylint: disable-msg=import-outside-toplevel - ells = two_point_0.ells - cells_gg = two_point_0.cells[("shear", "shear")] - cells_gi = two_point_0.cells[("shear", "intrinsic_pt")] - cells_ii = two_point_0.cells[("intrinsic_pt", "intrinsic_pt")] - cells_total = two_point_0.cells["total"] + ells = two_point_0.ells_for_xi + cells_gg = two_point_0.cells[TracerNames("shear", "shear")] + cells_gi = two_point_0.cells[TracerNames("shear", "intrinsic_pt")] + cells_ii = two_point_0.cells[TracerNames("intrinsic_pt", "intrinsic_pt")] + cells_total = two_point_0.cells[TRACER_NAMES_TOTAL] # pylint: enable=no-member # Code that computes effect from IA using that Pk2D object t_lens = ccl.WeakLensingTracer(ccl_cosmo, dndz=(z, nz)) diff --git a/examples/des_y1_3x2pt/des_y1_cosmic_shear_pk_modifier.py b/examples/des_y1_3x2pt/des_y1_cosmic_shear_pk_modifier.py index 6166b1fc..40620eb6 100644 --- a/examples/des_y1_3x2pt/des_y1_cosmic_shear_pk_modifier.py +++ b/examples/des_y1_3x2pt/des_y1_cosmic_shear_pk_modifier.py @@ -12,10 +12,12 @@ import firecrown.likelihood.weak_lensing as wl from firecrown.likelihood.two_point import TwoPoint from firecrown.likelihood.gaussian import ConstGaussian -from firecrown.parameters import ParamsMap, create +from firecrown.parameters import ParamsMap, register_new_updatable_parameter from firecrown.modeling_tools import ModelingTools, PowerspectrumModifier from firecrown.likelihood.likelihood import Likelihood - +from firecrown.ccl_factory import CCLFactory +from firecrown.updatable import get_default_params_map +from firecrown.metadata_types import TracerNames SACCFILE = os.path.expanduser( os.path.expandvars( @@ -36,7 +38,7 @@ def __init__(self, pk_to_modify: str = "delta_matter:delta_matter"): super().__init__() self.pk_to_modify = pk_to_modify self.vD19 = pyccl.baryons.BaryonsvanDaalen19() - self.f_bar = create() + self.f_bar = register_new_updatable_parameter(default_value=0.5) def compute_p_of_k_z(self, tools: ModelingTools) -> pyccl.Pk2D: """Compute the 3D power spectrum P(k, z).""" @@ -56,7 +58,9 @@ def build_likelihood(_) -> tuple[Likelihood, ModelingTools]: # Define the power spectrum modification and add it to the ModelingTools pk_modifier = vanDaalen19Baryonfication(pk_to_modify="delta_matter:delta_matter") - modeling_tools = ModelingTools(pk_modifiers=[pk_modifier]) + modeling_tools = ModelingTools( + pk_modifiers=[pk_modifier], 
ccl_factory=CCLFactory(require_nonlinear_pk=True) + ) # Create the likelihood from the statistics likelihood = ConstGaussian(statistics=list(stats.values())) @@ -123,18 +127,7 @@ def run_likelihood() -> None: src0_tracer = sacc_data.get_tracer("src0") z, nz = src0_tracer.z, src0_tracer.nz - # Define a ccl.Cosmology object using default parameters - ccl_cosmo = ccl.CosmologyVanillaLCDM() - ccl_cosmo.compute_nonlin_power() - f_bar = 0.5 - - # Calculate the barynic effects directly with CCL - vD19 = pyccl.BaryonsvanDaalen19(fbar=f_bar) - pk_baryons = vD19.include_baryonic_effects( - cosmo=ccl_cosmo, pk=ccl_cosmo.get_nonlin_power() - ) - # Set the parameters for our systematics systematics_params = ParamsMap( { @@ -145,13 +138,22 @@ def run_likelihood() -> None: "src3_delta_z": 0.002, } ) + # Prepare the cosmology object + params = ParamsMap(get_default_params_map(tools) | systematics_params) - # Apply the systematics parameters - likelihood.update(systematics_params) + tools.update(params) + tools.prepare() - # Prepare the cosmology object - tools.update(systematics_params) - tools.prepare(ccl_cosmo) + ccl_cosmo = tools.get_ccl_cosmology() + + # Calculate the barynic effects directly with CCL + vD19 = pyccl.BaryonsvanDaalen19(fbar=f_bar) + pk_baryons = vD19.include_baryonic_effects( + cosmo=ccl_cosmo, pk=ccl_cosmo.get_nonlin_power() + ) + + # Apply the systematics parameters + likelihood.update(params) # Compute the log-likelihood, using the ccl.Cosmology object as the input log_like = likelihood.compute_loglike(tools) @@ -166,7 +168,7 @@ def run_likelihood() -> None: # Predict CCL Cl wl_tracer = ccl.WeakLensingTracer(ccl_cosmo, dndz=(z, nz)) - ell = two_point_0.ells + ell = two_point_0.ells_for_xi cl_dm = ccl.angular_cl( cosmo=ccl_cosmo, tracer1=wl_tracer, @@ -191,7 +193,7 @@ def make_plot(ell, cl_dm, cl_baryons, two_point_0): """Create and show a diagnostic plot.""" import matplotlib.pyplot as plt # pylint: disable-msg=import-outside-toplevel - cl_firecrown = two_point_0.cells[("shear", "shear")] + cl_firecrown = two_point_0.cells[TracerNames("shear", "shear")] plt.plot(ell, cl_firecrown / cl_dm, label="firecrown w/ baryons") plt.plot(ell, cl_baryons / cl_dm, ls="--", label="CCL w/ baryons") diff --git a/firecrown/ccl_factory.py b/firecrown/ccl_factory.py new file mode 100644 index 00000000..14f03e00 --- /dev/null +++ b/firecrown/ccl_factory.py @@ -0,0 +1,233 @@ +"""This module contains the CCLFactory class. + +The CCLFactory class is a factory class that creates instances of the +`pyccl.Cosmology` class. 
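+
+Once its parameters have been updated, the factory can build either a full
+`pyccl.Cosmology`, or a `pyccl.CosmologyCalculator` when precomputed background
+and power-spectrum tables are supplied via `CCLCalculatorArgs`.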
+""" + +from typing import Annotated +from enum import Enum, auto + +# To be moved to the import from typing when migrating to Python 3.11 +from typing_extensions import NotRequired, TypedDict + +import numpy as np +import numpy.typing as npt +from pydantic import ( + BaseModel, + ConfigDict, + BeforeValidator, + SerializerFunctionWrapHandler, + SerializationInfo, + Field, + field_serializer, + model_serializer, +) + +import pyccl +from pyccl.neutrinos import NeutrinoMassSplits + +from firecrown.updatable import Updatable +from firecrown.parameters import register_new_updatable_parameter +from firecrown.utils import YAMLSerializable + +PowerSpec = TypedDict( + "PowerSpec", + { + "a": npt.NDArray[np.float64], + "k": npt.NDArray[np.float64], + "delta_matter:delta_matter": npt.NDArray[np.float64], + }, +) + +Background = TypedDict( + "Background", + { + "a": npt.NDArray[np.float64], + "chi": npt.NDArray[np.float64], + "h_over_h0": npt.NDArray[np.float64], + }, +) + +CCLCalculatorArgs = TypedDict( + "CCLCalculatorArgs", + { + "background": Background, + "pk_linear": NotRequired[PowerSpec], + "pk_nonlin": NotRequired[PowerSpec], + }, +) + + +class PoweSpecAmplitudeParameter(YAMLSerializable, str, Enum): + """This class defines the two-point correlation space. + + The two-point correlation space can be either real or harmonic. The real space + corresponds measurements in terms of angular separation, while the harmonic space + corresponds to measurements in terms of spherical harmonics decomposition. + """ + + @staticmethod + def _generate_next_value_(name, _start, _count, _last_values): + return name.lower() + + AS = auto() + SIGMA8 = auto() + + +def _validade_amplitude_parameter(value): + if isinstance(value, str): + try: + return PoweSpecAmplitudeParameter( + value.lower() + ) # Convert from string to Enum + except ValueError as exc: + raise ValueError( + f"Invalid value for PoweSpecAmplitudeParameter: {value}" + ) from exc + return value + + +def _validate_neutrino_mass_splits(value): + if isinstance(value, str): + try: + return NeutrinoMassSplits(value.lower()) # Convert from string to Enum + except ValueError as exc: + raise ValueError(f"Invalid value for NeutrinoMassSplits: {value}") from exc + return value + + +class CCLFactory(Updatable, BaseModel): + """Factory class for creating instances of the `pyccl.Cosmology` class.""" + + model_config = ConfigDict(extra="allow") + require_nonlinear_pk: Annotated[bool, Field(frozen=True)] = False + amplitude_parameter: Annotated[ + PoweSpecAmplitudeParameter, + BeforeValidator(_validade_amplitude_parameter), + Field(frozen=True), + ] = PoweSpecAmplitudeParameter.SIGMA8 + mass_split: Annotated[ + NeutrinoMassSplits, + BeforeValidator(_validate_neutrino_mass_splits), + Field(frozen=True), + ] = NeutrinoMassSplits.NORMAL + + def __init__(self, **data): + """Initialize the CCLFactory object.""" + parameter_prefix = parameter_prefix = data.pop("parameter_prefix", None) + BaseModel.__init__(self, **data) + Updatable.__init__(self, parameter_prefix=parameter_prefix) + + self._ccl_cosmo: None | pyccl.Cosmology = None + + ccl_cosmo = pyccl.CosmologyVanillaLCDM() + + self.Omega_c = register_new_updatable_parameter( + default_value=ccl_cosmo["Omega_c"] + ) + self.Omega_b = register_new_updatable_parameter( + default_value=ccl_cosmo["Omega_b"] + ) + self.h = register_new_updatable_parameter(default_value=ccl_cosmo["h"]) + self.n_s = register_new_updatable_parameter(default_value=ccl_cosmo["n_s"]) + self.Omega_k = register_new_updatable_parameter( + 
default_value=ccl_cosmo["Omega_k"] + ) + self.Neff = register_new_updatable_parameter(default_value=ccl_cosmo["Neff"]) + self.m_nu = register_new_updatable_parameter(default_value=ccl_cosmo["m_nu"]) + self.w0 = register_new_updatable_parameter(default_value=ccl_cosmo["w0"]) + self.wa = register_new_updatable_parameter(default_value=ccl_cosmo["wa"]) + self.T_CMB = register_new_updatable_parameter(default_value=ccl_cosmo["T_CMB"]) + + match self.amplitude_parameter: + case PoweSpecAmplitudeParameter.AS: + # VanillaLCDM has does not have A_s, so we need to add it + self.A_s = register_new_updatable_parameter(default_value=2.1e-9) + case PoweSpecAmplitudeParameter.SIGMA8: + assert ccl_cosmo["sigma8"] is not None + self.sigma8 = register_new_updatable_parameter( + default_value=ccl_cosmo["sigma8"] + ) + + @model_serializer(mode="wrap") + def serialize_model(self, nxt: SerializerFunctionWrapHandler, _: SerializationInfo): + """Serialize the CCLFactory object.""" + model_dump = nxt(self) + exclude_params = [param.name for param in self._sampler_parameters] + list( + self._internal_parameters.keys() + ) + + return {k: v for k, v in model_dump.items() if k not in exclude_params} + + @field_serializer("amplitude_parameter") + @classmethod + def serialize_amplitude_parameter(cls, value: PoweSpecAmplitudeParameter) -> str: + """Serialize the amplitude parameter.""" + return value.name + + @field_serializer("mass_split") + @classmethod + def serialize_mass_split(cls, value: NeutrinoMassSplits) -> str: + """Serialize the mass split parameter.""" + return value.name + + def model_post_init(self, __context) -> None: + """Initialize the WeakLensingFactory object.""" + + def create( + self, calculator_args: CCLCalculatorArgs | None = None + ) -> pyccl.Cosmology: + """Create a `pyccl.Cosmology` object.""" + if not self.is_updated(): + raise ValueError("Parameters have not been updated yet.") + + if self._ccl_cosmo is not None: + raise ValueError("CCLFactory object has already been created.") + + # pylint: disable=duplicate-code + ccl_args = { + "Omega_c": self.Omega_c, + "Omega_b": self.Omega_b, + "h": self.h, + "n_s": self.n_s, + "Omega_k": self.Omega_k, + "Neff": self.Neff, + "m_nu": self.m_nu, + "w0": self.w0, + "wa": self.wa, + "T_CMB": self.T_CMB, + "mass_split": self.mass_split.value, + } + # pylint: enable=duplicate-code + match self.amplitude_parameter: + case PoweSpecAmplitudeParameter.AS: + ccl_args["A_s"] = self.A_s + case PoweSpecAmplitudeParameter.SIGMA8: + ccl_args["sigma8"] = self.sigma8 + + assert ("A_s" in ccl_args) or ("sigma8" in ccl_args) + + if calculator_args is not None: + ccl_args.update(calculator_args) + if ("pk_nonlin" not in ccl_args) and self.require_nonlinear_pk: + ccl_args["nonlinear_model"] = "halofit" + else: + ccl_args["nonlinear_model"] = None + + return pyccl.CosmologyCalculator(**ccl_args) + + if self.require_nonlinear_pk: + ccl_args["matter_power_spectrum"] = "halofit" + + self._ccl_cosmo = pyccl.Cosmology(**ccl_args) + return self._ccl_cosmo + + def _reset(self) -> None: + """Reset the CCLFactory object.""" + self._ccl_cosmo = None + + def get(self) -> pyccl.Cosmology: + """Return the `pyccl.Cosmology` object.""" + if self._ccl_cosmo is None: + raise ValueError("CCLFactory object has not been created yet.") + return self._ccl_cosmo diff --git a/firecrown/connector/cobaya/ccl.py b/firecrown/connector/cobaya/ccl.py index 34a9e766..7d0f6cc1 100644 --- a/firecrown/connector/cobaya/ccl.py +++ b/firecrown/connector/cobaya/ccl.py @@ -7,11 +7,11 @@ import numpy as np import 
numpy.typing as npt -import pyccl from cobaya.theory import Theory from firecrown.connector.mapping import mapping_builder, MappingCAMB +from firecrown.ccl_factory import CCLCalculatorArgs class CCLConnector(Theory): @@ -110,9 +110,7 @@ def must_provide(self, **requirements) -> None: This version does nothing. """ - def calculate( - self, state: dict[str, float], want_derived=True, **params_values - ) -> None: + def calculate(self, state: dict, want_derived=True, **params_values) -> None: """Calculate the current cosmology, and set state["pyccl"] to the result. :param state: The state dictionary to update. @@ -120,10 +118,9 @@ def calculate( :param params_values: The values of the parameters to use. """ self.map.set_params_from_camb(**params_values) - pyccl_params_values = self.map.asdict() - # This is the dictionary appropriate for CCL creation + # This is the dictionary appropriate for CCL creation chi_arr = self.provider.get_comoving_radial_distance(self.z_bg) hoh0_arr = self.provider.get_Hubble(self.z_bg) / self.map.get_H0() k, z, pk = self.provider.get_Pk_grid() @@ -135,21 +132,23 @@ def calculate( self.a_Pk = self.map.redshift_to_scale_factor(z) pk_a = self.map.redshift_to_scale_factor_p_k(pk) - cosmo = pyccl.CosmologyCalculator( - **pyccl_params_values, - background={"a": self.a_bg, "chi": chi_arr, "h_over_h0": hoh0_arr}, - pk_linear={ - "a": self.a_Pk, - "k": k, - "delta_matter:delta_matter": pk_a, - }, - nonlinear_model="halofit", - ) - state["pyccl"] = cosmo - - def get_pyccl(self) -> pyccl.Cosmology: - """Return the current cosmology. - - :return: The current cosmology + pyccl_args: CCLCalculatorArgs = { + "background": {"a": self.a_bg, "chi": chi_arr, "h_over_h0": hoh0_arr}, + "pk_linear": {"a": self.a_Pk, "k": k, "delta_matter:delta_matter": pk_a}, + } + state["pyccl_args"] = pyccl_args + state["pyccl_params"] = pyccl_params_values + + def get_pyccl_args(self) -> CCLCalculatorArgs: + """Return the current CCL arguments. + + :return: The current CCL arguments + """ + return self.current_state["pyccl_args"] + + def get_pyccl_params(self) -> dict[str, float]: + """Return the current cosmological parameters. + + :return: The current cosmological parameters. """ - return self.current_state["pyccl"] + return self.current_state["pyccl_params"] diff --git a/firecrown/connector/cobaya/likelihood.py b/firecrown/connector/cobaya/likelihood.py index 16c05d98..54dcc304 100644 --- a/firecrown/connector/cobaya/likelihood.py +++ b/firecrown/connector/cobaya/likelihood.py @@ -12,6 +12,7 @@ from firecrown.likelihood.likelihood import load_likelihood, NamedParameters from firecrown.parameters import ParamsMap +from firecrown.ccl_factory import PoweSpecAmplitudeParameter class LikelihoodConnector(Likelihood): @@ -81,21 +82,29 @@ def get_requirements( """Returns a dictionary. Returns a dictionary with keys corresponding the contained likelihood's - required parameter, plus "pyccl". All values are None. + required parameter, plus "pyccl_args" and "pyccl_params". All values are None. Required by Cobaya. :return: a dictionary """ likelihood_requires: dict[ str, None | dict[str, npt.NDArray[np.float64]] | dict[str, object] - ] = {"pyccl": None} + ] = {"pyccl_args": None, "pyccl_params": None} required_params = ( self.likelihood.required_parameters() + self.tools.required_parameters() ) + # Cosmological parameters differ from Cobaya's, so we need to remove them. 
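+        # They are instead supplied by the CCLConnector theory block through
+        # the "pyccl_params" requirement.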
+ required_params -= self.tools.ccl_factory.required_parameters() for param_name in required_params.get_params_names(): likelihood_requires[param_name] = None + if ( + self.tools.ccl_factory.amplitude_parameter + == PoweSpecAmplitudeParameter.SIGMA8 + ): + likelihood_requires["sigma8"] = None + return likelihood_requires def must_provide(self, **requirements) -> None: @@ -110,18 +119,21 @@ def logp(self, **params_values) -> float: Required by Cobaya. :params values: The values of the parameters to use. """ - pyccl = self.provider.get_pyccl() + pyccl_args = self.provider.get_pyccl_args() + pyccl_params = self.provider.get_pyccl_params() - self.likelihood.update(ParamsMap(params_values)) - self.tools.update(ParamsMap(params_values)) - self.tools.prepare(pyccl) + derived = params_values.pop("_derived", {}) + params = ParamsMap(params_values | pyccl_params) + self.likelihood.update(params) + self.tools.update(params) + self.tools.prepare(calculator_args=pyccl_args) loglike = self.likelihood.compute_loglike(self.tools) derived_params_collection = self.likelihood.get_derived_parameters() assert derived_params_collection is not None for section, name, val in derived_params_collection: - params_values["_derived"][f"{section}__{name}"] = val + derived[f"{section}__{name}"] = val self.likelihood.reset() self.tools.reset() diff --git a/firecrown/connector/cosmosis/likelihood.py b/firecrown/connector/cosmosis/likelihood.py index c7446df3..9163ceaf 100644 --- a/firecrown/connector/cosmosis/likelihood.py +++ b/firecrown/connector/cosmosis/likelihood.py @@ -11,7 +11,6 @@ import cosmosis.datablock from cosmosis.datablock import option_section from cosmosis.datablock import names as section_names -import pyccl as ccl from firecrown.connector.mapping import mapping_builder, MappingCosmoSIS from firecrown.likelihood.gaussfamily import GaussFamily @@ -52,17 +51,16 @@ def __init__(self, config: cosmosis.datablock) -> None: if likelihood_source == "": likelihood_source = config[option_section, "firecrown_config"] - require_nonlinear_pk = config.get_bool( - option_section, "require_nonlinear_pk", False - ) - build_parameters = extract_section(config, option_section) - sections = config.get_string(option_section, "sampling_parameters_sections", "") - sections = sections.split() + sections_str: str = config.get_string( + option_section, "sampling_parameters_sections", "" + ) + assert isinstance(sections_str, str) + sections = sections_str.split() self.firecrown_module_name = option_section - self.sampling_sections = sections + self.sampling_sections: list[str] = sections self.likelihood: Likelihood try: self.likelihood, self.tools = load_likelihood( @@ -75,9 +73,7 @@ def __init__(self, config: cosmosis.datablock) -> None: raise # We have to do some extra type-fiddling here because mapping_builder # has a declared return type of the base class. 
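+        # require_nonlinear_pk is no longer forwarded to the mapping; the
+        # CCLFactory held by the likelihood's ModelingTools now handles it.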
- new_mapping = mapping_builder( - input_style="CosmoSIS", require_nonlinear_pk=require_nonlinear_pk - ) + new_mapping = mapping_builder(input_style="CosmoSIS") assert isinstance(new_mapping, MappingCosmoSIS) self.map = new_mapping @@ -88,6 +84,8 @@ def __init__(self, config: cosmosis.datablock) -> None: required_parameters = ( self.likelihood.required_parameters() + self.tools.required_parameters() ) + required_parameters -= self.tools.ccl_factory.required_parameters() + if len(required_parameters) != 0: msg = ( f"The configured likelihood has required " @@ -111,10 +109,7 @@ def execute(self, sample: cosmosis.datablock) -> int: sample, "cosmological_parameters" ) self.map.set_params_from_cosmosis(cosmological_params) - - ccl_cosmo = ccl.CosmologyCalculator( - **self.map.asdict(), **self.map.calculate_ccl_args(sample) - ) + ccl_args = self.map.calculate_ccl_args(sample) # TODO: Future development will need to capture elements that get put into the # datablock. This probably will be in a different "physics module" and not in @@ -122,9 +117,11 @@ def execute(self, sample: cosmosis.datablock) -> int: # calculations. e.g., data_vector/firecrown_theory data_vector/firecrown_data firecrown_params = self.calculate_firecrown_params(sample) + firecrown_params = ParamsMap(firecrown_params | self.map.asdict()) + firecrown_params.use_lower_case_keys(True) self.update_likelihood_and_tools(firecrown_params) - self.tools.prepare(ccl_cosmo) + self.tools.prepare(calculator_args=ccl_args) loglike = self.likelihood.compute_loglike(self.tools) derived_params_collection = self.likelihood.get_derived_parameters() diff --git a/firecrown/connector/mapping.py b/firecrown/connector/mapping.py index ab1a61e0..dc2ff5a3 100644 --- a/firecrown/connector/mapping.py +++ b/firecrown/connector/mapping.py @@ -9,7 +9,7 @@ import typing import warnings from abc import ABC -from typing import Any, Type, final +from typing import Type, final import cosmosis.datablock import numpy as np @@ -18,6 +18,7 @@ from firecrown.descriptors import TypeFloat, TypeString from firecrown.likelihood.likelihood import NamedParameters +from firecrown.ccl_factory import CCLCalculatorArgs, PowerSpec, Background def build_ccl_background_dict( @@ -25,7 +26,7 @@ def build_ccl_background_dict( a: npt.NDArray[np.float64], chi: npt.NDArray[np.float64], h_over_h0: npt.NDArray[np.float64], -) -> dict[str, npt.NDArray[np.float64]]: +) -> Background: """Builds the CCL dictionary of background quantities. :param a: The scale factor array @@ -66,13 +67,8 @@ class Mapping(ABC): wa = TypeFloat() T_CMB = TypeFloat() - def __init__(self, *, require_nonlinear_pk: bool = False) -> None: - """Initialize the Mapping object. 
- - :param require_nonlinear_pk: Whether the mapping requires the - non-linear power spectrum - """ - self.require_nonlinear_pk = require_nonlinear_pk + def __init__(self) -> None: + """Initialize the Mapping object.""" self.m_nu: float | list[float] | None = None def get_params_names(self) -> list[str]: @@ -139,7 +135,6 @@ def set_params( Omega_k: float, Neff: float, m_nu: float | list[float], - m_nu_type: str, w0: float, wa: float, T_CMB: float, @@ -158,7 +153,6 @@ def set_params( :param Omega_k: curvature of the universe :param Neff: effective number of relativistic neutrino species :param m_nu: effective mass of neutrinos - :param m_nu_type: type of massive neutrinos :param w0: constant of the CPL parameterization of the dark energy equation of state :param wa: linear coefficient of the CPL parameterization of the @@ -186,7 +180,6 @@ def set_params( self.Omega_g = None self.Neff = Neff self.m_nu = m_nu - self.m_nu_type = m_nu_type self.w0 = w0 self.wa = wa self.T_CMB = T_CMB @@ -220,7 +213,7 @@ def redshift_to_scale_factor_p_k( p_k_out = np.flipud(p_k) return p_k_out - def asdict(self) -> dict[str, None | float | list[float] | str]: + def asdict(self) -> dict[str, float | list[float]]: """Return a dictionary containing the cosmological constants. :return: the dictionary, containing keys: @@ -234,29 +227,34 @@ def asdict(self) -> dict[str, None | float | list[float] | str]: - ``Omega_k``: curvature of the universe - ``Neff``: effective number of relativistic neutrino species - ``m_nu``: effective mass of neutrinos - - ``m_nu_type``: type of massive neutrinos - ``w0``: constant of the CPL parameterization of the dark energy equation of state - ``wa``: linear coefficient of the CPL parameterization of the dark energy equation of state - ``T_CMB``: cosmic microwave background temperature today """ - return { + cosmo_dict: dict[str, float | list[float]] = { "Omega_c": self.Omega_c, "Omega_b": self.Omega_b, "h": self.h, - "A_s": self.A_s, - "sigma8": self.sigma8, "n_s": self.n_s, "Omega_k": self.Omega_k, - "Omega_g": self.Omega_g, "Neff": self.Neff, - "m_nu": self.m_nu, - "mass_split": self.m_nu_type, "w0": self.w0, "wa": self.wa, "T_CMB": self.T_CMB, } + if self.A_s is not None: + cosmo_dict["A_s"] = self.A_s + if self.sigma8 is not None: + cosmo_dict["sigma8"] = self.sigma8 + # Currently we do not support Omega_g + # if self.Omega_g is not None: + # cosmo_dict["Omega_g"] = self.Omega_g + if self.m_nu is not None: + cosmo_dict["m_nu"] = self.m_nu + + return cosmo_dict def get_H0(self) -> float: """Return the value of H0. @@ -346,7 +344,6 @@ def set_params_from_cosmosis(self, cosmosis_params: NamedParameters) -> None: Neff = delta_neff + 3.046 omega_nu = cosmosis_params.get_float("omega_nu") m_nu = omega_nu * h * h * 93.14 - m_nu_type = "normal" w0 = cosmosis_params.get_float("w") wa = cosmosis_params.get_float("wa") @@ -360,7 +357,6 @@ def set_params_from_cosmosis(self, cosmosis_params: NamedParameters) -> None: Omega_k=Omega_k, Neff=Neff, m_nu=m_nu, - m_nu_type=m_nu_type, w0=w0, wa=-wa, # Is this minus sign here correct? T_CMB=2.7255, @@ -368,13 +364,14 @@ def set_params_from_cosmosis(self, cosmosis_params: NamedParameters) -> None: ) # pylint: enable=duplicate-code - def calculate_ccl_args(self, sample: cosmosis.datablock) -> dict[str, Any]: + def calculate_ccl_args(self, sample: cosmosis.datablock) -> CCLCalculatorArgs: """Calculate the arguments necessary for CCL for this sample. 
:param sample: the datablock for the current sample :return: the arguments required by CCL """ - ccl_args: dict[str, Any] = {} + pk_linear: None | PowerSpec = None + pk_nonlin: None | PowerSpec = None if sample.has_section("matter_power_lin"): k = self.transform_k_h_to_k(sample["matter_power_lin", "k_h"]) z_mpl = sample["matter_power_lin", "z"] @@ -382,7 +379,7 @@ def calculate_ccl_args(self, sample: cosmosis.datablock) -> dict[str, Any]: p_k = self.transform_p_k_h3_to_p_k(sample["matter_power_lin", "p_k"]) p_k = self.redshift_to_scale_factor_p_k(p_k) - ccl_args["pk_linear"] = { + pk_linear = { "a": scale_mpl, "k": k, "delta_matter:delta_matter": p_k, @@ -395,15 +392,11 @@ def calculate_ccl_args(self, sample: cosmosis.datablock) -> dict[str, Any]: p_k = self.transform_p_k_h3_to_p_k(sample["matter_power_nl", "p_k"]) p_k = self.redshift_to_scale_factor_p_k(p_k) - ccl_args["pk_nonlin"] = { + pk_nonlin = { "a": scale_mpl, "k": k, "delta_matter:delta_matter": p_k, } - elif self.require_nonlinear_pk: - ccl_args["nonlinear_model"] = "halofit" - else: - ccl_args["nonlinear_model"] = None # TODO: We should have several configurable modes for this module. # In all cases, an exception will be raised (causing a program @@ -437,9 +430,15 @@ def calculate_ccl_args(self, sample: cosmosis.datablock) -> dict[str, Any]: # h_over_h0 = np.flip(sample["distances", "h"]) * hubble_radius_today h_over_h0 = self.transform_h_to_h_over_h0(sample["distances", "h"]) - ccl_args["background"] = build_ccl_background_dict( + background: Background = build_ccl_background_dict( a=scale_distances, chi=chi, h_over_h0=h_over_h0 ) + ccl_args: CCLCalculatorArgs = {"background": background} + + if pk_linear is not None: + ccl_args.update({"pk_linear": pk_linear}) + if pk_nonlin is not None: + ccl_args.update({"pk_nonlin": pk_nonlin}) return ccl_args @@ -487,7 +486,6 @@ def set_params_from_camb(self, **params_values) -> None: m_nu = params_values["mnu"] Omega_k0 = params_values["omk"] - m_nu_type = "normal" h0 = H0 / 100.0 h02 = h0 * h0 Omega_b0 = ombh2 / h02 @@ -513,7 +511,6 @@ def set_params_from_camb(self, **params_values) -> None: sigma8=None, A_s=As, m_nu=m_nu, - m_nu_type=m_nu_type, w0=w, wa=wa, Neff=Neff, diff --git a/firecrown/connector/numcosmo/numcosmo.py b/firecrown/connector/numcosmo/numcosmo.py index 0f88d04c..0e39f934 100644 --- a/firecrown/connector/numcosmo/numcosmo.py +++ b/firecrown/connector/numcosmo/numcosmo.py @@ -4,9 +4,9 @@ be used without an installation of NumCosmo. """ -from typing import Any +import warnings + import numpy as np -import pyccl as ccl from numcosmo_py import Nc, Ncm, GObject, var_dict_to_dict, dict_to_var_dict @@ -17,6 +17,12 @@ from firecrown.parameters import ParamsMap from firecrown.connector.mapping import Mapping, build_ccl_background_dict from firecrown.modeling_tools import ModelingTools +from firecrown.ccl_factory import ( + CCLCalculatorArgs, + PowerSpec, + CCLFactory, + PoweSpecAmplitudeParameter, +) def get_hiprim(hi_cosmo: Nc.HICosmo) -> Nc.HIPrimPowerLaw: @@ -48,27 +54,33 @@ class MappingNumCosmo(GObject.Object): def __init__( self, - require_nonlinear_pk: bool = False, + require_nonlinear_pk: None | bool = None, p_ml: None | Nc.PowspecML = None, p_mnl: None | Nc.PowspecMNL = None, dist: None | Nc.Distance = None, ) -> None: """Initialize a MappingNumCosmo object. 
- :param require_nonlinear_pk: whether to require a nonlinear power spectrum :param p_ml: optional PowspecML object :param p_mnl: optional PowspecMNL object :param dist: optional Distance object """ - super().__init__( # type: ignore - require_nonlinear_pk=require_nonlinear_pk, p_ml=p_ml, p_mnl=p_mnl, dist=dist - ) + super().__init__(p_ml=p_ml, p_mnl=p_mnl, dist=dist) # type: ignore self.mapping: Mapping self._mapping_name: str self._p_ml: None | Nc.PowspecML self._p_mnl: None | Nc.PowspecMNL self._dist: Nc.Distance + if require_nonlinear_pk is not None: + warnings.warn( + "The require_nonlinear_pk argument is deprecated and will be removed " + "in future versions. This configuration is now handled by the " + "likelihood factory function.", + DeprecationWarning, + stacklevel=2, + ) + if not hasattr(self, "_p_ml"): self._p_ml = None @@ -101,28 +113,6 @@ def _set_mapping_name(self, value: str) -> None: setter=_set_mapping_name, ) - def _get_require_nonlinear_pk(self) -> bool: - """Return whether nonlinear power spectra are required. - - :return: whether nonlinear power spectra are required - """ - return self.mapping.require_nonlinear_pk - - def _set_require_nonlinear_pk(self, value: bool) -> None: - """Set whether nonlinear power spectra are required. - - :param value: whether nonlinear power spectra are required - """ - self.mapping.require_nonlinear_pk = value - - require_nonlinear_pk = GObject.Property( - type=bool, - default=False, - flags=GObject.ParamFlags.READWRITE, - getter=_get_require_nonlinear_pk, - setter=_set_require_nonlinear_pk, - ) - def _get_p_ml(self) -> None | Nc.PowspecML: """Return the NumCosmo PowspecML object. @@ -184,7 +174,7 @@ def _set_dist(self, value: Nc.Distance) -> None: ) def set_params_from_numcosmo( - self, mset: Ncm.MSet + self, mset: Ncm.MSet, ccl_factory: CCLFactory ) -> None: # pylint: disable-msg=too-many-locals """Set the parameters of the contained Mapping object. @@ -207,10 +197,7 @@ def set_params_from_numcosmo( T_gamma0 = hi_cosmo.T_gamma0() m_nu: float | list[float] = 0.0 - if hi_cosmo.NMassNu() == 0: - m_nu_type = "normal" - else: - m_nu_type = "list" + if hi_cosmo.NMassNu() > 0: assert hi_cosmo.NMassNu() <= 3 m_nu = [hi_cosmo.MassNuInfo(i)[0] for i in range(hi_cosmo.NMassNu())] @@ -224,19 +211,31 @@ def set_params_from_numcosmo( case _: raise ValueError(f"NumCosmo object {type(hi_cosmo)} not supported.") - hiprim = get_hiprim(hi_cosmo) + A_s = None + sigma8 = None + match ccl_factory.amplitude_parameter: + case PoweSpecAmplitudeParameter.SIGMA8: + if self._p_ml is None: + raise ValueError( + "PowspecML object must be provided when using sigma8." + ) + sigma8 = self._p_ml.sigma_tophat_R(hi_cosmo, 1.0e-7, 0.0, 8.0 / h) + case PoweSpecAmplitudeParameter.AS: + A_s = get_hiprim(hi_cosmo).SA_Ampl() + + assert (A_s is not None) or (sigma8 is not None) # pylint: disable=duplicate-code self.mapping.set_params( Omega_c=Omega_c, Omega_b=Omega_b, h=h, - A_s=hiprim.SA_Ampl(), - n_s=hiprim.props.n_SA, + A_s=A_s, + sigma8=sigma8, + n_s=get_hiprim(hi_cosmo).props.n_SA, Omega_k=Omega_k, Neff=Neff, m_nu=m_nu, - m_nu_type=m_nu_type, w0=w0, wa=wa, T_CMB=T_gamma0, @@ -245,13 +244,14 @@ def set_params_from_numcosmo( def calculate_ccl_args( # pylint: disable-msg=too-many-locals self, mset: Ncm.MSet - ) -> dict[str, Any]: + ) -> CCLCalculatorArgs: """Calculate the arguments necessary for CCL for this sample. 
:param mset: the NumCosmo MSet object from which to get the parameters :return: a dictionary of the arguments required by CCL """ - ccl_args: dict[str, Any] = {} + pk_linear: None | PowerSpec = None + pk_nonlin: None | PowerSpec = None hi_cosmo = mset.peek(Nc.HICosmo.id()) assert isinstance(hi_cosmo, Nc.HICosmo) @@ -265,8 +265,7 @@ def calculate_ccl_args( # pylint: disable-msg=too-many-locals np.array(p_m_spline.peek_zm().dup_array()).reshape(len(k), len(z)) ) p_k = self.mapping.redshift_to_scale_factor_p_k(p_k) - - ccl_args["pk_linear"] = { + pk_linear = { "a": scale, "k": k, "delta_matter:delta_matter": p_k, @@ -282,16 +281,11 @@ def calculate_ccl_args( # pylint: disable-msg=too-many-locals np.array(p_mnl_spline.peek_zm().dup_array()).reshape(len(k), len(z)) ) p_mnl = self.mapping.redshift_to_scale_factor_p_k(p_mnl) - - ccl_args["pk_nonlin"] = { + pk_nonlin = { "a": scale_mpnl, "k": k, "delta_matter:delta_matter": p_mnl, } - elif self.mapping.require_nonlinear_pk: - ccl_args["nonlinear_model"] = "halofit" - else: - ccl_args["nonlinear_model"] = None d_spline = self._dist.comoving_distance_spline.peek_spline() z_dist = np.array(d_spline.get_xv().dup_array()) @@ -309,15 +303,16 @@ def calculate_ccl_args( # pylint: disable-msg=too-many-locals chi = chi[a_unique_indices] h_over_h0 = h_over_h0[a_unique_indices] - ccl_args["background"] = build_ccl_background_dict( - a=scale_distances, chi=chi, h_over_h0=h_over_h0 - ) - - ccl_args["background"] = { - "a": scale_distances, - "chi": chi, - "h_over_h0": h_over_h0, + ccl_args: CCLCalculatorArgs = { + "background": build_ccl_background_dict( + a=scale_distances, chi=chi, h_over_h0=h_over_h0 + ) } + if pk_linear: + ccl_args["pk_linear"] = pk_linear + if pk_nonlin: + ccl_args["pk_nonlin"] = pk_nonlin + return ccl_args def create_params_map(self, model_list: list[str], mset: Ncm.MSet) -> ParamsMap: @@ -341,15 +336,22 @@ def create_params_map(self, model_list: list[str], mset: Ncm.MSet) -> ParamsMap: model_dict = { param: model.param_get_by_name(param) for param in param_names } - shared_keys = set(model_dict).intersection(params_map) - if len(shared_keys) > 0: - raise RuntimeError( - f"The following keys `{shared_keys}` appear " - f"in more than one model used by the " - f"module {model_list}." - ) - params_map = ParamsMap({**params_map, **model_dict}) + params_map = self._update_params_map(model_list, params_map, model_dict) + params_map = self._update_params_map( + model_list, params_map, self.mapping.asdict() + ) + + return params_map + + def _update_params_map(self, model_list, params_map, model_dict): + shared_keys = set(model_dict).intersection(params_map) + if len(shared_keys) > 0: + raise RuntimeError( + f"The following keys `{shared_keys}` appear in more than one model " + f"used by the module {model_list} or cosmological parameters." 
+ ) + params_map = ParamsMap({**params_map, **model_dict}) return params_map @@ -370,7 +372,6 @@ def __init__(self): super().__init__() self.likelihood: Likelihood self.tools: ModelingTools - self.ccl_cosmo: None | ccl.Cosmology = None self._model_list: list[str] self._nc_mapping: MappingNumCosmo self._likelihood_source: None | str = None @@ -566,16 +567,13 @@ def do_prepare( # pylint: disable-msg=arguments-differ self.likelihood.reset() self.tools.reset() - self._nc_mapping.set_params_from_numcosmo(mset) + self._nc_mapping.set_params_from_numcosmo(mset, self.tools.ccl_factory) ccl_args = self._nc_mapping.calculate_ccl_args(mset) - self.ccl_cosmo = ccl.CosmologyCalculator( - **self._nc_mapping.mapping.asdict(), **ccl_args - ) params_map = self._nc_mapping.create_params_map(self.model_list, mset) self.likelihood.update(params_map) self.tools.update(params_map) - self.tools.prepare(self.ccl_cosmo) + self.tools.prepare(calculator_args=ccl_args) def do_m2lnL_val(self, _) -> float: # pylint: disable-msg=arguments-differ """Implements the virtual method `m2lnL`. @@ -615,7 +613,6 @@ def __init__(self) -> None: super().__init__() self.likelihood: ConstGaussian self.tools: ModelingTools - self.ccl_cosmo: ccl.Cosmology self.dof: int self.len: int self._model_list: list[str] @@ -675,7 +672,6 @@ def _configure_object(self) -> None: assert nrows == ncols self.set_size(nrows) - self.ccl_cosmo = None self.dof = nrows self.len = nrows self.peek_cov().set_from_array( # pylint: disable-msg=no-member @@ -843,17 +839,13 @@ def do_prepare( # pylint: disable-msg=arguments-differ self.likelihood.reset() self.tools.reset() - self._nc_mapping.set_params_from_numcosmo(mset) + self._nc_mapping.set_params_from_numcosmo(mset, self.tools.ccl_factory) ccl_args = self._nc_mapping.calculate_ccl_args(mset) - - self.ccl_cosmo = ccl.CosmologyCalculator( - **self._nc_mapping.mapping.asdict(), **ccl_args - ) params_map = self._nc_mapping.create_params_map(self._model_list, mset) self.likelihood.update(params_map) self.tools.update(params_map) - self.tools.prepare(self.ccl_cosmo) + self.tools.prepare(calculator_args=ccl_args) # pylint: disable-next=arguments-differ def do_mean_func(self, _, vp) -> None: diff --git a/firecrown/likelihood/factories.py b/firecrown/likelihood/factories.py new file mode 100644 index 00000000..a00510a9 --- /dev/null +++ b/firecrown/likelihood/factories.py @@ -0,0 +1,224 @@ +""" +Factory functions for creating likelihoods from SACC files. + +This module provides factory functions to create likelihood objects by combining a SACC +file and a set of statistic factories. Users can define their own custom statistic +factories for advanced use cases or rely on the generic factory functions provided here +for simpler scenarios. + +For straightforward contexts where all data in the SACC file is utilized, the generic +factories simplify the process. The user only needs to supply the SACC file and specify +which statistic factories to use, and the likelihood factory will handle the creation of +the likelihood object, assembling the necessary components automatically. + +These functions are particularly useful when the full set of statistics present in a +SACC file is being used without the need for complex customization. 
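+
+The YAML layout consumed by these factories is described by the `TwoPointExperiment`
+model defined below: a `data_source` naming the SACC file and a `two_point_factory`
+configuring the weak-lensing and number-counts factories.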
+""" + +from typing import Annotated +from enum import Enum, auto +from pathlib import Path + +import yaml +from pydantic import BaseModel, ConfigDict, BeforeValidator + +import sacc +from firecrown.likelihood.likelihood import Likelihood, NamedParameters +from firecrown.likelihood.gaussian import ConstGaussian +from firecrown.likelihood.weak_lensing import WeakLensingFactory +from firecrown.likelihood.number_counts import NumberCountsFactory +from firecrown.likelihood.two_point import TwoPoint +from firecrown.data_functions import ( + extract_all_real_data, + extract_all_harmonic_data, + check_two_point_consistence_real, + check_two_point_consistence_harmonic, +) +from firecrown.modeling_tools import ModelingTools +from firecrown.utils import YAMLSerializable + + +class TwoPointCorrelationSpace(YAMLSerializable, str, Enum): + """This class defines the two-point correlation space. + + The two-point correlation space can be either real or harmonic. The real space + corresponds measurements in terms of angular separation, while the harmonic space + corresponds to measurements in terms of spherical harmonics decomposition. + """ + + @staticmethod + def _generate_next_value_(name, _start, _count, _last_values): + return name.lower() + + REAL = auto() + HARMONIC = auto() + + +def _validate_correlation_space(value): + if isinstance(value, str): + try: + return TwoPointCorrelationSpace(value) # Convert from string to Enum + except ValueError as exc: + raise ValueError( + f"Invalid value for TwoPointCorrelationSpace: {value}" + ) from exc + return value + + +class TwoPointFactory(BaseModel): + """Factory class for WeakLensing objects.""" + + model_config = ConfigDict(extra="forbid", frozen=True) + + correlation_space: Annotated[ + TwoPointCorrelationSpace, BeforeValidator(_validate_correlation_space) + ] + weak_lensing_factory: WeakLensingFactory + number_counts_factory: NumberCountsFactory + + def model_post_init(self, __context) -> None: + """Initialize the WeakLensingFactory object.""" + + +class DataSourceSacc(BaseModel): + """Model for the data source in a likelihood configuration.""" + + sacc_data_file: str + + def model_post_init(self, __context) -> None: + """Initialize the DataSourceSacc object.""" + sacc_data_file = Path(self.sacc_data_file) + if not sacc_data_file.exists(): + raise FileNotFoundError(f"File {sacc_data_file} does not exist") + + def get_sacc_data(self) -> sacc.Sacc: + """Load the SACC data file.""" + return sacc.Sacc.load_fits(self.sacc_data_file) + + +class TwoPointExperiment(BaseModel): + """Model for the two-point experiment in a likelihood configuration.""" + + two_point_factory: TwoPointFactory + data_source: DataSourceSacc + + def model_post_init(self, __context) -> None: + """Initialize the TwoPointExperiment object.""" + + +def build_two_point_likelihood( + build_parameters: NamedParameters, +) -> tuple[Likelihood, ModelingTools]: + """ + Build a likelihood object for two-point statistics from a SACC file. + + This function creates a likelihood object for two-point statistics using a SACC file + and a set of statistic factories. The user must provide the SACC file and specify + which statistic factories to use. The likelihood object is created by combining the + SACC file with the specified statistic factories. + + :param build_parameters: A NamedParameters object containing the following + parameters: + - sacc_file: The SACC file containing the data. + - statistic_factories: A YAML file containing the statistic factories to use. 
+    """
+    likelihood_config_file = build_parameters.get_string("likelihood_config")
+
+    with open(likelihood_config_file, "r", encoding="utf-8") as f:
+        likelihood_config = yaml.safe_load(f)
+
+    if likelihood_config is None:
+        raise ValueError("No likelihood config found.")
+
+    exp = TwoPointExperiment.model_validate(likelihood_config, strict=True)
+    modeling_tools = ModelingTools()
+
+    # Load the SACC file
+    sacc_data = exp.data_source.get_sacc_data()
+
+    match exp.two_point_factory.correlation_space:
+        case TwoPointCorrelationSpace.REAL:
+            likelihood = _build_two_point_likelihood_real(
+                sacc_data,
+                exp.two_point_factory.weak_lensing_factory,
+                exp.two_point_factory.number_counts_factory,
+            )
+        case TwoPointCorrelationSpace.HARMONIC:
+            likelihood = _build_two_point_likelihood_harmonic(
+                sacc_data,
+                exp.two_point_factory.weak_lensing_factory,
+                exp.two_point_factory.number_counts_factory,
+            )
+
+    return likelihood, modeling_tools
+
+
+def _build_two_point_likelihood_harmonic(
+    sacc_data: sacc.Sacc,
+    wl_factory: WeakLensingFactory,
+    nc_factory: NumberCountsFactory,
+) -> ConstGaussian:
+    """
+    Build a likelihood object for two-point statistics in harmonic space.
+
+    This function extracts all harmonic-space two-point measurements from the
+    given SACC data, checks their consistency, builds the corresponding TwoPoint
+    statistics using the given factories, and wraps the statistics in a Gaussian
+    likelihood.
+
+    :param sacc_data: The SACC data object containing the measurements.
+    :param wl_factory: The weak lensing statistic factory.
+    :param nc_factory: The number counts statistic factory.
+
+    :return: A likelihood object for two-point statistics in harmonic space.
+    """
+    tpms = extract_all_harmonic_data(sacc_data)
+    if len(tpms) == 0:
+        raise ValueError(
+            "No two-point measurements in harmonic space found in the SACC file."
+        )
+
+    check_two_point_consistence_harmonic(tpms)
+
+    two_points = TwoPoint.from_measurement(
+        tpms, wl_factory=wl_factory, nc_factory=nc_factory
+    )
+
+    likelihood = ConstGaussian.create_ready(two_points, sacc_data.covariance.dense)
+
+    return likelihood
+
+
+def _build_two_point_likelihood_real(
+    sacc_data: sacc.Sacc,
+    wl_factory: WeakLensingFactory,
+    nc_factory: NumberCountsFactory,
+) -> ConstGaussian:
+    """
+    Build a likelihood object for two-point statistics in real space.
+
+    This function extracts all real-space two-point measurements from the given
+    SACC data, checks their consistency, builds the corresponding TwoPoint
+    statistics using the given factories, and wraps the statistics in a Gaussian
+    likelihood.
+
+    :param sacc_data: The SACC data object containing the measurements.
+    :param wl_factory: The weak lensing statistic factory.
+    :param nc_factory: The number counts statistic factory.
+
+    :return: A likelihood object for two-point statistics in real space.
+    """
+    tpms = extract_all_real_data(sacc_data)
+    if len(tpms) == 0:
+        raise ValueError(
+            "No two-point measurements in real space found in the SACC file."
+ ) + check_two_point_consistence_real(tpms) + + two_points = TwoPoint.from_measurement( + tpms, wl_factory=wl_factory, nc_factory=nc_factory + ) + + likelihood = ConstGaussian.create_ready(two_points, sacc_data.covariance.dense) + + return likelihood diff --git a/firecrown/likelihood/likelihood.py b/firecrown/likelihood/likelihood.py index 2aa75155..9059f3ba 100644 --- a/firecrown/likelihood/likelihood.py +++ b/firecrown/likelihood/likelihood.py @@ -321,7 +321,9 @@ def convert_to_basic_dict( def load_likelihood_from_module_type( - module: types.ModuleType, build_parameters: NamedParameters + module: types.ModuleType, + build_parameters: NamedParameters, + build_likelihood_name: str = "build_likelihood", ) -> tuple[Likelihood, ModelingTools]: """Loads a likelihood from a module type. @@ -337,28 +339,29 @@ def load_likelihood_from_module_type( function parameters :return : a tuple of the likelihood and the modeling tools """ - if not hasattr(module, "build_likelihood"): + if not hasattr(module, build_likelihood_name): if not hasattr(module, "likelihood"): raise AttributeError( f"Firecrown initialization module {module.__name__} in " f"{module.__file__} does not define " - f"a `build_likelihood` factory function." + f"a `{build_likelihood_name}` factory function." ) warnings.warn( - "The use of a likelihood variable in Firecrown's initialization " - "module is deprecated. Any parameters passed to the likelihood " - "will be ignored. The module should define a `build_likelihood` " - "factory function.", + f"The use of a likelihood variable in Firecrown's initialization " + f"module is deprecated. Any parameters passed to the likelihood " + f"will be ignored. The module should define the `{build_likelihood_name}` " + f"factory function.", category=DeprecationWarning, ) likelihood = module.likelihood tools = ModelingTools() else: - if not callable(module.build_likelihood): + build_likelihood = getattr(module, build_likelihood_name) + if not callable(build_likelihood): raise TypeError( - "The factory function `build_likelihood` must be a callable." + f"The factory function `{build_likelihood_name}` must be a callable." ) - build_return = module.build_likelihood(build_parameters) + build_return = build_likelihood(build_parameters) if isinstance(build_return, tuple): likelihood, tools = build_return else: @@ -443,13 +446,26 @@ def load_likelihood_from_module( :return : a tuple of the likelihood and the modeling tools """ try: - mod = importlib.import_module(module) + # Try importing the entire string as a module first + try: + mod = importlib.import_module(module) + func = "build_likelihood" + except ImportError as sub_exc: + # If it fails, split and try importing as module.function + if "." not in module: + raise sub_exc + module_name, func = module.rsplit(".", 1) + mod = importlib.import_module(module_name) except ImportError as exc: raise ValueError( - f"Unrecognized Firecrown initialization module {module}." + f"Unrecognized Firecrown initialization module '{module}'. " + f"The module must be either a module_name or a module_name.func " + f"where func is the factory function." 
        ) from exc

-    return load_likelihood_from_module_type(mod, build_parameters)
+    return load_likelihood_from_module_type(
+        mod, build_parameters, build_likelihood_name=func
+    )


 def load_likelihood(
diff --git a/firecrown/modeling_tools.py b/firecrown/modeling_tools.py
index 35e94796..9a53eb6a 100644
--- a/firecrown/modeling_tools.py
+++ b/firecrown/modeling_tools.py
@@ -15,6 +15,7 @@

 from firecrown.models.cluster.abundance import ClusterAbundance
 from firecrown.updatable import Updatable, UpdatableCollection
+from firecrown.ccl_factory import CCLFactory, CCLCalculatorArgs


 class ModelingTools(Updatable):
@@ -30,6 +31,7 @@ def __init__(
         pt_calculator: None | pyccl.nl_pt.EulerianPTCalculator = None,
         pk_modifiers: None | Collection[PowerspectrumModifier] = None,
         cluster_abundance: None | ClusterAbundance = None,
+        ccl_factory: None | CCLFactory = None,
     ):
         super().__init__()
         self.ccl_cosmo: None | pyccl.Cosmology = None
@@ -39,6 +41,7 @@ def __init__(
         self.powerspectra: dict[str, pyccl.Pk2D] = {}
         self._prepared: bool = False
         self.cluster_abundance = cluster_abundance
+        self.ccl_factory = CCLFactory() if ccl_factory is None else ccl_factory

     def add_pk(self, name: str, powerspectrum: pyccl.Pk2D) -> None:
         """Add a :python:`pyccl.Pk2D` to the table of power spectra."""
@@ -71,7 +74,7 @@ def has_pk(self, name: str) -> bool:
             return False
         return True

-    def prepare(self, ccl_cosmo: pyccl.Cosmology) -> None:
+    def prepare(self, *, calculator_args: None | CCLCalculatorArgs = None) -> None:
         """Prepare the Cosmology for use in likelihoods.

         This method will prepare the ModelingTools for use in likelihoods. This
@@ -88,16 +91,17 @@ def prepare(self, ccl_cosmo: pyccl.Cosmology) -> None:

         if self.ccl_cosmo is not None:
             raise RuntimeError("Cosmology has already been set")
-        self.ccl_cosmo = ccl_cosmo
+
+        self.ccl_cosmo = self.ccl_factory.create(calculator_args)

         if self.pt_calculator is not None:
-            self.pt_calculator.update_ingredients(ccl_cosmo)
+            self.pt_calculator.update_ingredients(self.ccl_cosmo)

         for pkm in self.pk_modifiers:
             self.add_pk(name=pkm.name, powerspectrum=pkm.compute_p_of_k_z(tools=self))

         if self.cluster_abundance is not None:
-            self.cluster_abundance.update_ingredients(ccl_cosmo)
+            self.cluster_abundance.update_ingredients(self.ccl_cosmo)

         self._prepared = True

diff --git a/firecrown/parameters.py b/firecrown/parameters.py
index 81b6d9f6..79f017af 100644
--- a/firecrown/parameters.py
+++ b/firecrown/parameters.py
@@ -35,6 +35,28 @@ def parameter_get_full_name(prefix: None | str, param: str) -> str:
     return param


+def _validate_params_map_value(name: str, value: float | list[float]) -> None:
+    """Check that the value is a float or a list of floats.
+
+    Raises a TypeError if the value is not a float or a list of floats.
+
+    :param name: name of the parameter
+    :param value: value to be checked
+    """
+    if not isinstance(value, (float, list)):
+        raise TypeError(
+            f"Value for parameter {name} is not a float or a list of floats: "
+            f"{type(value)}"
+        )
+
+    if isinstance(value, list):
+        if not all(isinstance(v, float) for v in value):
+            raise TypeError(
+                f"Value for parameter {name} is not a float or a list of floats: "
+                f"{type(value)}"
+            )
+
+
 class ParamsMap(dict[str, float]):
     """A specialized dict in which all keys are strings and values are floats.
@@ -44,6 +66,9 @@ class ParamsMap(dict[str, float]):

     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
+        for name, value in self.items():
+            _validate_params_map_value(name, value)
+
         self.lower_case: bool = False

     def use_lower_case_keys(self, enable: bool) -> None:
@@ -62,11 +87,14 @@ def get_from_full_name(self, full_name: str) -> float:

         Raises a KeyError if the parameter is not found.
         """
-        if self.lower_case:
-            full_name = full_name.lower()
-
         if full_name in self.keys():
             return self[full_name]
+
+        if self.lower_case:
+            full_name_lower = full_name.lower()
+            if full_name_lower in self.keys():
+                return self[full_name_lower]
+
         raise KeyError(f"Key {full_name} not found.")

     def get_from_prefix_param(self, prefix: None | str, param: str) -> float:
@@ -109,6 +137,14 @@ def __add__(self, other: RequiredParameters) -> RequiredParameters:
         """
         return RequiredParameters(self.params_set | other.params_set)

+    def __sub__(self, other: RequiredParameters) -> RequiredParameters:
+        """Return a new RequiredParameters with the names in self but not in other.
+
+        Note that this function returns a new object that does not share state
+        with either argument to the subtraction operator.
+        """
+        return RequiredParameters(self.params_set - other.params_set)
+
     def __eq__(self, other: object):
         """Compare two RequiredParameters objects for equality.

diff --git a/firecrown/updatable.py b/firecrown/updatable.py
index ff11af30..ed7136f6 100644
--- a/firecrown/updatable.py
+++ b/firecrown/updatable.py
@@ -46,9 +46,9 @@ def __init__(self, parameter: str) -> None:
         self.parameter = parameter
         msg = (
             f"The parameter `{parameter}` is required to update "
-            "something in this likelihood.\nIt should have been supplied"
-            "by the sampling framework.\nThe object being updated was:\n"
-            "{context}\n"
+            f"something in this likelihood.\nIt should have been supplied "
+            f"by the sampling framework.\nThe object being updated was:\n"
+            f"{self}\n"
         )
         super().__init__(msg)

@@ -351,9 +351,9 @@ def __init__(self, iterable: None | Iterable[T] = None) -> None:
         super().__init__(iterable)
         self._updated: bool = False
         for item in self:
-            if not isinstance(item, Updatable):
+            if not isinstance(item, (Updatable, UpdatableCollection)):
                 raise TypeError(
-                    "All the items in an UpdatableCollection must be updatable"
+                    f"All the items in an UpdatableCollection must be updatable: {item}"
                 )

     @final
@@ -435,3 +435,25 @@ def __setitem__(self, key, value):
             )

         super().__setitem__(key, cast(T, value))
+
+
+def get_default_params(*args: Updatable) -> dict[str, float]:
+    """Get a dictionary with the default values of all parameters in the updatables.
+
+    :param args: updatables to get the default parameters from
+    :return: a dictionary with the default values of all parameters
+    """
+    updatable_collection = UpdatableCollection(args)
+    required_parameters = updatable_collection.required_parameters()
+    default_parameters = required_parameters.get_default_values()
+
+    return default_parameters
+
+
+def get_default_params_map(*args: Updatable) -> ParamsMap:
+    """Get a ParamsMap with the default values of all parameters in the updatables.
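+
+    A minimal usage sketch (this mirrors how the test fixtures in this change
+    set prepare a ModelingTools object)::
+
+        tools = ModelingTools()
+        params = get_default_params_map(tools)
+        tools.update(params)
+        tools.prepare()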
+ + :param args: updatables to get the default parameters from + :return: a ParamsMap with the default values of all parameters + """ + return ParamsMap(get_default_params(*args)) diff --git a/tests/conftest.py b/tests/conftest.py index df9789cc..d135c1f4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,8 +11,8 @@ import numpy as np import numpy.typing as npt -import pyccl +from firecrown.updatable import get_default_params_map from firecrown.utils import upper_triangle_indices from firecrown.likelihood.statistic import TrivialStatistic from firecrown.parameters import ParamsMap @@ -119,7 +119,6 @@ def fixture_mapping_cosmosis() -> MappingCosmoSIS: Omega_k=0.0, Neff=3.046, m_nu=0.0, - m_nu_type="normal", w0=-1.0, wa=0.0, T_CMB=2.7255, @@ -156,8 +155,9 @@ def fixture_tools_with_vanilla_cosmology(): """Return a ModelingTools object containing the LCDM cosmology from pyccl.""" result = ModelingTools() - result.update(ParamsMap()) - result.prepare(pyccl.CosmologyVanillaLCDM()) + params = get_default_params_map(result) + result.update(params) + result.prepare() return result diff --git a/tests/connector/cobaya/test_model_likelihood.py b/tests/connector/cobaya/test_model_likelihood.py index c6446544..609add67 100644 --- a/tests/connector/cobaya/test_model_likelihood.py +++ b/tests/connector/cobaya/test_model_likelihood.py @@ -1,7 +1,6 @@ """Unit tests for the cobaya Mapping connector.""" import pytest -import pyccl as ccl from cobaya.model import get_model, Model from cobaya.log import LoggedError from firecrown.connector.cobaya.ccl import CCLConnector @@ -69,7 +68,7 @@ def test_cobaya_ccl_model(fiducial_params): "likelihood": { "test_lk": { "external": lambda _self=None: 0.0, - "requires": {"pyccl": None}, + "requires": {"pyccl_args": None, "pyccl_params": None}, } }, "theory": { @@ -82,20 +81,21 @@ def test_cobaya_ccl_model(fiducial_params): assert isinstance(model_fiducial, Model) model_fiducial.logposterior({}) - cosmo = model_fiducial.provider.get_pyccl() - assert isinstance(cosmo, ccl.Cosmology) + cosmo_args = model_fiducial.provider.get_pyccl_args() + cosmo_params = model_fiducial.provider.get_pyccl_params() + assert isinstance(cosmo_args, dict) h = fiducial_params["H0"] / 100.0 - assert cosmo["H0"] == pytest.approx(fiducial_params["H0"], rel=1.0e-5) - assert cosmo["Omega_c"] == pytest.approx( + assert cosmo_params["h"] * 100.0 == pytest.approx(fiducial_params["H0"], rel=1.0e-5) + assert cosmo_params["Omega_c"] == pytest.approx( fiducial_params["omch2"] / h**2, rel=1.0e-5 ) - assert cosmo["Omega_b"] == pytest.approx( + assert cosmo_params["Omega_b"] == pytest.approx( fiducial_params["ombh2"] / h**2, rel=1.0e-5 ) - assert cosmo["Omega_k"] == pytest.approx(0.0, rel=1.0e-5) - assert cosmo["A_s"] == pytest.approx(fiducial_params["As"], rel=1.0e-5) - assert cosmo["n_s"] == pytest.approx(fiducial_params["ns"], rel=1.0e-5) + assert cosmo_params["Omega_k"] == pytest.approx(0.0, rel=1.0e-5) + assert cosmo_params["A_s"] == pytest.approx(fiducial_params["As"], rel=1.0e-5) + assert cosmo_params["n_s"] == pytest.approx(fiducial_params["ns"], rel=1.0e-5) # The following test fails because of we are using the default # neutrino hierarchy, which is normal, while CAMB depends on the # parameter which we do not have access to. 
diff --git a/tests/connector/cosmosis/test_cosmosis_module.py b/tests/connector/cosmosis/test_cosmosis_module.py index 43de9f90..dc4394b0 100644 --- a/tests/connector/cosmosis/test_cosmosis_module.py +++ b/tests/connector/cosmosis/test_cosmosis_module.py @@ -555,29 +555,3 @@ def test_mapping_cosmosis_pk_nonlin(mapping_cosmosis): mapping_cosmosis.transform_p_k_h3_to_p_k(matter_power_nl_p_k) ), ) - - -def test_mapping_cosmosis_pk_nonlin_nonlinear_model(mapping_cosmosis): - block = DataBlock() - - block.put_double_array_1d("distances", "d_m", np.geomspace(0.1, 10.0, 100)) - block.put_double_array_1d("distances", "z", np.linspace(0.0, 2.0, 10)) - block.put_double_array_1d("distances", "h", np.geomspace(0.1, 10.0, 100)) - - mapping_cosmosis.require_nonlinear_pk = True - ccl_args = mapping_cosmosis.calculate_ccl_args(block) - - assert "background" in ccl_args - assert "pk_linear" not in ccl_args - assert "pk_nonlin" not in ccl_args - assert "nonlinear_model" in ccl_args - assert ccl_args["nonlinear_model"] - - mapping_cosmosis.require_nonlinear_pk = False - ccl_args = mapping_cosmosis.calculate_ccl_args(block) - - assert "background" in ccl_args - assert "pk_linear" not in ccl_args - assert "pk_nonlin" not in ccl_args - assert "nonlinear_model" in ccl_args - assert not ccl_args["nonlinear_model"] diff --git a/tests/connector/numcosmo/test_numcosmo_connector.py b/tests/connector/numcosmo/test_numcosmo_connector.py index cc8ca58d..eae4a4e0 100644 --- a/tests/connector/numcosmo/test_numcosmo_connector.py +++ b/tests/connector/numcosmo/test_numcosmo_connector.py @@ -18,10 +18,7 @@ @pytest.fixture(name="factory_plain") def fixture_factory_plain(): """Create a NumCosmoFactory instance.""" - map_cosmo = MappingNumCosmo( - require_nonlinear_pk=True, - dist=Nc.Distance.new(6.0), - ) + map_cosmo = MappingNumCosmo(dist=Nc.Distance.new(6.0)) build_parameters = NamedParameters() return NumCosmoFactory( "tests/likelihood/lkdir/lkscript.py", @@ -34,10 +31,7 @@ def fixture_factory_plain(): @pytest.fixture(name="factory_const_gauss") def fixture_factory_const_gauss(): """Create a NumCosmoFactory instance.""" - map_cosmo = MappingNumCosmo( - require_nonlinear_pk=True, - dist=Nc.Distance.new(6.0), - ) + map_cosmo = MappingNumCosmo(dist=Nc.Distance.new(6.0)) build_parameters = NamedParameters() return NumCosmoFactory( "tests/likelihood/gauss_family/lkscript_const_gaussian.py", diff --git a/tests/connector/numcosmo/test_numcosmo_mapping.py b/tests/connector/numcosmo/test_numcosmo_mapping.py index a116ede2..13640d05 100644 --- a/tests/connector/numcosmo/test_numcosmo_mapping.py +++ b/tests/connector/numcosmo/test_numcosmo_mapping.py @@ -14,6 +14,7 @@ MappingNumCosmo, NumCosmoFactory, ) +from firecrown.ccl_factory import CCLFactory, PoweSpecAmplitudeParameter Ncm.cfg_init() @@ -24,11 +25,7 @@ def test_numcosmo_mapping_create_params_map_non_existing_model(): cosmo = Nc.HICosmoDEXcdm() - map_cosmo = MappingNumCosmo( - require_nonlinear_pk=True, - dist=Nc.Distance.new(6.0), - ) - + map_cosmo = MappingNumCosmo(dist=Nc.Distance.new(6.0)) mset = Ncm.MSet() mset.set(cosmo) @@ -45,11 +42,7 @@ def test_numcosmo_mapping_create_params_map_absent_model(): cosmo = Nc.HICosmoDEXcdm() - map_cosmo = MappingNumCosmo( - require_nonlinear_pk=True, - dist=Nc.Distance.new(6.0), - ) - + map_cosmo = MappingNumCosmo(dist=Nc.Distance.new(6.0)) mset = Ncm.MSet() mset.set(cosmo) @@ -65,11 +58,7 @@ def test_numcosmo_mapping_create_params_map_two_models_sharing_parameters(): with an existing type but not present in the model set.""" cosmo = 
Nc.HICosmoDEXcdm() - - map_cosmo = MappingNumCosmo( - require_nonlinear_pk=True, - dist=Nc.Distance.new(6.0), - ) + map_cosmo = MappingNumCosmo(dist=Nc.Distance.new(6.0)) mset = Ncm.MSet() mset.set(cosmo) @@ -153,16 +142,12 @@ def test_numcosmo_mapping_unsupported(): cosmo = Nc.HICosmoDEJbp() - map_cosmo = MappingNumCosmo( - require_nonlinear_pk=True, - dist=Nc.Distance.new(6.0), - ) - + map_cosmo = MappingNumCosmo(dist=Nc.Distance.new(6.0)) mset = Ncm.MSet() mset.set(cosmo) with pytest.raises(ValueError, match="NumCosmo object .* not supported."): - map_cosmo.set_params_from_numcosmo(mset) + map_cosmo.set_params_from_numcosmo(mset, CCLFactory()) def test_numcosmo_mapping_missing_hiprim(): @@ -170,18 +155,16 @@ def test_numcosmo_mapping_missing_hiprim(): cosmo = Nc.HICosmoDECpl() - map_cosmo = MappingNumCosmo( - require_nonlinear_pk=True, - dist=Nc.Distance.new(6.0), - ) - + map_cosmo = MappingNumCosmo(dist=Nc.Distance.new(6.0)) mset = Ncm.MSet() mset.set(cosmo) with pytest.raises( ValueError, match="NumCosmo object must include a HIPrim object." ): - map_cosmo.set_params_from_numcosmo(mset) + map_cosmo.set_params_from_numcosmo( + mset, CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.AS) + ) def test_numcosmo_mapping_invalid_hiprim(): @@ -191,64 +174,16 @@ def test_numcosmo_mapping_invalid_hiprim(): prim = Nc.HIPrimAtan() cosmo.add_submodel(prim) - map_cosmo = MappingNumCosmo( - require_nonlinear_pk=True, - dist=Nc.Distance.new(6.0), - ) - + map_cosmo = MappingNumCosmo(dist=Nc.Distance.new(6.0)) mset = Ncm.MSet() mset.set(cosmo) with pytest.raises( ValueError, match="NumCosmo HIPrim object type .* not supported." ): - map_cosmo.set_params_from_numcosmo(mset) - - -def test_numcosmo_mapping_no_p_mnl_require_nonlinear_pk(): - """Test the NumCosmo mapping connector with a model without p_mnl but - with require_nonlinear_pk=True.""" - - cosmo = Nc.HICosmoDECpl() - prim = Nc.HIPrimPowerLaw() - cosmo.add_submodel(prim) - - map_cosmo = MappingNumCosmo( - require_nonlinear_pk=True, - dist=Nc.Distance.new(6.0), - ) - - mset = Ncm.MSet() - mset.set(cosmo) - - map_cosmo.set_params_from_numcosmo(mset) - - ccl_args = map_cosmo.calculate_ccl_args(mset) - - assert ccl_args["nonlinear_model"] == "halofit" - - -def test_numcosmo_mapping_no_p_mnl(): - """Test the NumCosmo mapping connector with a model without p_mnl and - require_nonlinear_pk=False.""" - - cosmo = Nc.HICosmoDECpl() - prim = Nc.HIPrimPowerLaw() - cosmo.add_submodel(prim) - - map_cosmo = MappingNumCosmo( - require_nonlinear_pk=False, - dist=Nc.Distance.new(6.0), - ) - - mset = Ncm.MSet() - mset.set(cosmo) - - map_cosmo.set_params_from_numcosmo(mset) - - ccl_args = map_cosmo.calculate_ccl_args(mset) - - assert ccl_args["nonlinear_model"] is None + map_cosmo.set_params_from_numcosmo( + mset, CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.AS) + ) @pytest.mark.parametrize( @@ -261,7 +196,6 @@ def test_numcosmo_mapping(numcosmo_cosmo_fixture, request): cosmo = numcosmo_cosmo["cosmo"] map_cosmo = MappingNumCosmo( - require_nonlinear_pk=True, p_ml=numcosmo_cosmo["p_ml"], p_mnl=numcosmo_cosmo["p_mnl"], dist=numcosmo_cosmo["dist"], @@ -270,7 +204,7 @@ def test_numcosmo_mapping(numcosmo_cosmo_fixture, request): mset = Ncm.MSet() mset.set(cosmo) - map_cosmo.set_params_from_numcosmo(mset) + map_cosmo.set_params_from_numcosmo(mset, CCLFactory()) ccl_args = map_cosmo.calculate_ccl_args(mset) ccl_cosmo = ccl.CosmologyCalculator(**map_cosmo.mapping.asdict(), **ccl_args) @@ -290,7 +224,6 @@ def 
test_numcosmo_serialize_mapping(numcosmo_cosmo_fixture, request): cosmo = numcosmo_cosmo["cosmo"] map_cosmo = MappingNumCosmo( - require_nonlinear_pk=True, p_ml=numcosmo_cosmo["p_ml"], p_mnl=numcosmo_cosmo["p_mnl"], dist=numcosmo_cosmo["dist"], @@ -307,8 +240,8 @@ def test_numcosmo_serialize_mapping(numcosmo_cosmo_fixture, request): mset = Ncm.MSet() mset.set(cosmo) - map_cosmo.set_params_from_numcosmo(mset) - map_cosmo_dup.set_params_from_numcosmo(mset) + map_cosmo.set_params_from_numcosmo(mset, CCLFactory()) + map_cosmo_dup.set_params_from_numcosmo(mset, CCLFactory()) if map_cosmo_dup.p_ml is None: assert map_cosmo_dup.p_ml is None @@ -343,7 +276,6 @@ def test_numcosmo_data( cosmo = numcosmo_cosmo["cosmo"] map_cosmo = MappingNumCosmo( - require_nonlinear_pk=True, p_ml=numcosmo_cosmo["p_ml"], p_mnl=numcosmo_cosmo["p_mnl"], dist=numcosmo_cosmo["dist"], @@ -395,7 +327,6 @@ def test_numcosmo_gauss_cov( cosmo = numcosmo_cosmo["cosmo"] map_cosmo = MappingNumCosmo( - require_nonlinear_pk=True, p_ml=numcosmo_cosmo["p_ml"], p_mnl=numcosmo_cosmo["p_mnl"], dist=numcosmo_cosmo["dist"], @@ -451,7 +382,6 @@ def test_numcosmo_serialize_likelihood( numcosmo_cosmo = request.getfixturevalue(numcosmo_cosmo_fixture) map_cosmo = MappingNumCosmo( - require_nonlinear_pk=True, p_ml=numcosmo_cosmo["p_ml"], p_mnl=numcosmo_cosmo["p_mnl"], dist=numcosmo_cosmo["dist"], @@ -490,3 +420,27 @@ def test_numcosmo_serialize_likelihood( assert id(fc_data_dup.model_list) != id(fc_data.model_list) assert isinstance(fc_data_dup.model_list, list) assert fc_data_dup.model_list == fc_data.model_list + + +def test_numcosmo_mapping_deprecated_require_nonlinear_pk(): + """Test the MappingNumCosmo deprecated require_nonlinear_pk.""" + + with pytest.deprecated_call(): + _ = MappingNumCosmo(require_nonlinear_pk=True) + + +def test_numcosmo_mapping_sigma8_missing_pk(): + """Test the MappingNumCosmo with sigma8 as a parameter but missing pk.""" + + cosmo = Nc.HICosmoDEXcdm() + + map_cosmo = MappingNumCosmo(dist=Nc.Distance.new(6.0)) + mset = Ncm.MSet() + mset.set(cosmo) + + with pytest.raises( + ValueError, match="PowspecML object must be provided when using sigma8." 
+ ): + map_cosmo.set_params_from_numcosmo( + mset, CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.SIGMA8) + ) diff --git a/tests/connector/test_mapping.py b/tests/connector/test_mapping.py index 3fef8693..27824c2b 100644 --- a/tests/connector/test_mapping.py +++ b/tests/connector/test_mapping.py @@ -53,7 +53,6 @@ def test_conversion_from_cosmosis_camb(): assert p.Omega_g is None assert p.Neff == pytest.approx(3.046) assert p.m_nu == pytest.approx(0.3015443336635814) - assert p.m_nu_type == "normal" # Currently the only option assert p.w0 == cosmosis_params["w"] assert p.wa == cosmosis_params["wa"] assert p.T_CMB == 2.7255 # currently hard-wired @@ -135,7 +134,6 @@ def test_sigma8_and_A_s(): "Omega_k": 0.0, "Neff": 3.046, "m_nu": 0.0, - "m_nu_type": "normal", "w0": -1.0, "wa": 0.0, "T_CMB": 2.7255, diff --git a/tests/integration/test_des_y1_3x2pt.py b/tests/integration/test_des_y1_3x2pt.py index 75c3a7a8..99fee1ab 100644 --- a/tests/integration/test_des_y1_3x2pt.py +++ b/tests/integration/test_des_y1_3x2pt.py @@ -15,6 +15,7 @@ def test_des_y1_3x2pt_cosmosis(): cd examples/des_y1_3x2pt cosmosis des_y1_3x2pt.ini cosmosis des_y1_3x2pt_PT.ini + cosmosis des_y1_3x2pt_default_factory.ini """, ], capture_output=True, diff --git a/tests/likelihood/gauss_family/lkscript_const_gaussian.py b/tests/likelihood/gauss_family/lkscript_const_gaussian.py index c58be271..e8a37893 100644 --- a/tests/likelihood/gauss_family/lkscript_const_gaussian.py +++ b/tests/likelihood/gauss_family/lkscript_const_gaussian.py @@ -7,6 +7,8 @@ from firecrown.likelihood.gaussian import ConstGaussian from firecrown.likelihood.supernova import Supernova +from firecrown.modeling_tools import ModelingTools +from firecrown.ccl_factory import CCLFactory, PoweSpecAmplitudeParameter def build_likelihood(_): @@ -24,4 +26,9 @@ def build_likelihood(_): sacc_data.add_data_point("supernova_distance_mu", tracer_tuple, -3.0, z=0.3) sacc_data.add_covariance([4.0, 9.0, 16.0]) likelihood.read(sacc_data) - return likelihood + + tools = ModelingTools( + ccl_factory=CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.SIGMA8) + ) + + return likelihood, tools diff --git a/tests/likelihood/gauss_family/lkscript_two_point.py b/tests/likelihood/gauss_family/lkscript_two_point.py index b8ec49e3..fff9029d 100644 --- a/tests/likelihood/gauss_family/lkscript_two_point.py +++ b/tests/likelihood/gauss_family/lkscript_two_point.py @@ -12,6 +12,8 @@ from firecrown.likelihood.number_counts import ( NumberCounts, ) +from firecrown.modeling_tools import ModelingTools +from firecrown.ccl_factory import CCLFactory, PoweSpecAmplitudeParameter def build_likelihood(params: NamedParameters): @@ -42,4 +44,9 @@ def build_likelihood(params: NamedParameters): likelihood = ConstGaussian(statistics=[two_point]) likelihood.read(sacc_data) - return likelihood + + tools = ModelingTools( + ccl_factory=CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.SIGMA8) + ) + + return likelihood, tools diff --git a/tests/likelihood/gauss_family/statistic/test_binned_cluster_number_counts.py b/tests/likelihood/gauss_family/statistic/test_binned_cluster_number_counts.py index 9597320e..58e5f2dc 100644 --- a/tests/likelihood/gauss_family/statistic/test_binned_cluster_number_counts.py +++ b/tests/likelihood/gauss_family/statistic/test_binned_cluster_number_counts.py @@ -4,18 +4,18 @@ import sacc import pytest import pyccl +from firecrown.updatable import get_default_params_map from firecrown.models.cluster.recipes.cluster_recipe import ClusterRecipe from 
firecrown.likelihood.source import SourceSystematic from firecrown.modeling_tools import ModelingTools from firecrown.models.cluster.properties import ClusterProperty -from firecrown.parameters import ParamsMap from firecrown.models.cluster.abundance import ClusterAbundance from firecrown.likelihood.binned_cluster_number_counts import ( BinnedClusterNumberCounts, ) -def test_create_binned_number_counts(): +def test_create_binned_number_counts() -> None: recipe = Mock(spec=ClusterRecipe) bnc = BinnedClusterNumberCounts(ClusterProperty.NONE, "Test", recipe) assert bnc is not None @@ -36,7 +36,7 @@ def test_create_binned_number_counts(): assert bnc.systematics == systematics -def test_get_data_vector(): +def test_get_data_vector() -> None: recipe = Mock(spec=ClusterRecipe) bnc = BinnedClusterNumberCounts(ClusterProperty.NONE, "Test", recipe) dv = bnc.get_data_vector() @@ -44,7 +44,7 @@ def test_get_data_vector(): assert len(dv) == 0 -def test_read_throws_if_no_property(cluster_sacc_data: sacc.Sacc): +def test_read_throws_if_no_property(cluster_sacc_data: sacc.Sacc) -> None: recipe = Mock(spec=ClusterRecipe) bnc = BinnedClusterNumberCounts(ClusterProperty.NONE, "my_survey", recipe) @@ -55,7 +55,7 @@ def test_read_throws_if_no_property(cluster_sacc_data: sacc.Sacc): bnc.read(cluster_sacc_data) -def test_read_single_property(cluster_sacc_data: sacc.Sacc): +def test_read_single_property(cluster_sacc_data: sacc.Sacc) -> None: recipe = Mock(spec=ClusterRecipe) bnc = BinnedClusterNumberCounts(ClusterProperty.COUNTS, "my_survey", recipe) @@ -75,7 +75,7 @@ def test_read_single_property(cluster_sacc_data: sacc.Sacc): assert len(bnc.sacc_indices) == 2 -def test_read_multiple_properties(cluster_sacc_data: sacc.Sacc): +def test_read_multiple_properties(cluster_sacc_data: sacc.Sacc) -> None: recipe = Mock(spec=ClusterRecipe) bnc = BinnedClusterNumberCounts( (ClusterProperty.COUNTS | ClusterProperty.MASS), "my_survey", recipe @@ -88,18 +88,16 @@ def test_read_multiple_properties(cluster_sacc_data: sacc.Sacc): assert len(bnc.sacc_indices) == 4 -def test_compute_theory_vector(cluster_sacc_data: sacc.Sacc): +def test_compute_theory_vector(cluster_sacc_data: sacc.Sacc) -> None: recipe = Mock(spec=ClusterRecipe) recipe.evaluate_theory_prediction.return_value = 1.0 tools = ModelingTools() hmf = pyccl.halos.MassFuncBocquet16() - cosmo = pyccl.cosmology.CosmologyVanillaLCDM() - params = ParamsMap() - tools.cluster_abundance = ClusterAbundance(13, 17, 0, 2, hmf) + params = get_default_params_map(tools) tools.update(params) - tools.prepare(cosmo) + tools.prepare() bnc = BinnedClusterNumberCounts(ClusterProperty.COUNTS, "my_survey", recipe) bnc.read(cluster_sacc_data) diff --git a/tests/likelihood/gauss_family/statistic/test_supernova.py b/tests/likelihood/gauss_family/statistic/test_supernova.py index 4a433bfc..c3932b73 100644 --- a/tests/likelihood/gauss_family/statistic/test_supernova.py +++ b/tests/likelihood/gauss_family/statistic/test_supernova.py @@ -4,11 +4,10 @@ import pytest import sacc -import pyccl +from firecrown.updatable import get_default_params_map from firecrown.likelihood.supernova import Supernova from firecrown.modeling_tools import ModelingTools -from firecrown.parameters import ParamsMap @pytest.fixture(name="minimal_stat") @@ -48,7 +47,7 @@ def fixture_sacc_data() -> sacc.Sacc: def test_missing_sacc_tracer_fails_read( minimal_stat: Supernova, missing_sacc_tracer: sacc.Sacc -): +) -> None: with pytest.raises( ValueError, match="The SACC file does not contain the MiscTracer sn_fake_sample", @@ 
-58,7 +57,7 @@ def test_missing_sacc_tracer_fails_read( def test_wrong_tracer_type_fails_read( minimal_stat: Supernova, wrong_tracer_type: sacc.Sacc -): +) -> None: with pytest.raises( ValueError, match=f"The SACC tracer {minimal_stat.sacc_tracer} is not a MiscTracer", @@ -66,7 +65,7 @@ def test_wrong_tracer_type_fails_read( minimal_stat.read(wrong_tracer_type) -def test_read_works(minimal_stat: Supernova, good_sacc_data: sacc.Sacc): +def test_read_works(minimal_stat: Supernova, good_sacc_data: sacc.Sacc) -> None: """After read() is called, we should be able to get the statistic's :class:`DataVector` and also should be able to call @@ -78,9 +77,10 @@ def test_read_works(minimal_stat: Supernova, good_sacc_data: sacc.Sacc): assert data_vector[0] == 16.95 tools = ModelingTools() - tools.update(ParamsMap()) - tools.prepare(pyccl.CosmologyVanillaLCDM()) - params = ParamsMap({"sn_fake_sample_M": 1.1}) + params = get_default_params_map(tools) + params.update({"sn_fake_sample_M": 1.1}) + tools.update(params) + tools.prepare() minimal_stat.update(params) theory_vector = minimal_stat.compute_theory_vector(tools) assert len(theory_vector) == 1 diff --git a/tests/likelihood/gauss_family/statistic/test_two_point.py b/tests/likelihood/gauss_family/statistic/test_two_point.py index 599769c3..d90626b1 100644 --- a/tests/likelihood/gauss_family/statistic/test_two_point.py +++ b/tests/likelihood/gauss_family/statistic/test_two_point.py @@ -8,8 +8,7 @@ from numpy.testing import assert_allclose import pytest -import pyccl - +from firecrown.updatable import get_default_params_map from firecrown.modeling_tools import ModelingTools from firecrown.parameters import ParamsMap @@ -194,20 +193,21 @@ def test_two_point_src0_src0_window(sacc_galaxy_cells_src0_src0_window) -> None: statistic.read(sacc_data) tools = ModelingTools() - tools.update(ParamsMap()) - tools.prepare(pyccl.CosmologyVanillaLCDM()) + params = get_default_params_map(tools) + tools.update(params) + tools.prepare() assert statistic.window is not None statistic.reset() - statistic.update(ParamsMap()) - tools.update(ParamsMap()) + statistic.update(params) + tools.update(params) result1 = statistic.compute_theory_vector(tools) assert all(np.isfinite(result1)) statistic.reset() - statistic.update(ParamsMap()) - tools.update(ParamsMap()) + statistic.update(params) + tools.update(params) result2 = statistic.compute_theory_vector(tools) assert np.array_equal(result1, result2) @@ -223,20 +223,21 @@ def test_two_point_src0_src0_no_window(sacc_galaxy_cells_src0_src0_no_window) -> statistic.read(sacc_data) tools = ModelingTools() - tools.update(ParamsMap()) - tools.prepare(pyccl.CosmologyVanillaLCDM()) + params = get_default_params_map(tools) + tools.update(params) + tools.prepare() assert statistic.window is None statistic.reset() - statistic.update(ParamsMap()) - tools.update(ParamsMap()) + statistic.update(params) + tools.update(params) result1 = statistic.compute_theory_vector(tools) assert all(np.isfinite(result1)) statistic.reset() - statistic.update(ParamsMap()) - tools.update(ParamsMap()) + statistic.update(params) + tools.update(params) result2 = statistic.compute_theory_vector(tools) assert np.array_equal(result1, result2) @@ -342,8 +343,9 @@ def test_two_point_src0_src0_cuts(sacc_galaxy_cells_src0_src0) -> None: statistic.read(sacc_data) tools = ModelingTools() - tools.update(ParamsMap()) - tools.prepare(pyccl.CosmologyVanillaLCDM()) + params = get_default_params_map(tools) + tools.update(params) + tools.prepare() assert statistic.window is None 
assert statistic.ells is not None @@ -351,7 +353,7 @@ def test_two_point_src0_src0_cuts(sacc_galaxy_cells_src0_src0) -> None: assert all(statistic.ells >= 50) assert all(statistic.ells <= 200) - statistic.update(ParamsMap()) + statistic.update(params) statistic.compute_theory_vector(tools) @@ -365,10 +367,11 @@ def test_two_point_lens0_lens0_cuts(sacc_galaxy_xis_lens0_lens0) -> None: ) statistic.read(sacc_data) - param_map = ParamsMap({"lens0_bias": 1.0}) tools = ModelingTools() - tools.update(param_map) - tools.prepare(pyccl.CosmologyVanillaLCDM()) + params = get_default_params_map(tools) + params.update({"lens0_bias": 1.0}) + tools.update(params) + tools.prepare() assert statistic.window is None assert statistic.thetas is not None @@ -376,7 +379,7 @@ def test_two_point_lens0_lens0_cuts(sacc_galaxy_xis_lens0_lens0) -> None: assert all(statistic.thetas >= 0.1) assert all(statistic.thetas <= 0.5) - statistic.update(param_map) + statistic.update(params) statistic.compute_theory_vector(tools) @@ -393,10 +396,11 @@ def test_two_point_lens0_lens0_config(sacc_galaxy_xis_lens0_lens0) -> None: ) statistic.read(sacc_data) - param_map = ParamsMap({"lens0_bias": 1.0}) tools = ModelingTools() - tools.update(param_map) - tools.prepare(pyccl.CosmologyVanillaLCDM()) + params = get_default_params_map(tools) + params.update({"lens0_bias": 1.0}) + tools.update(params) + tools.prepare() assert statistic.window is None assert statistic.thetas is not None @@ -407,7 +411,7 @@ def test_two_point_lens0_lens0_config(sacc_galaxy_xis_lens0_lens0) -> None: # on how many unique ells we get from the log-binning. assert len(statistic.ells_for_xi) == 175 - statistic.update(param_map) + statistic.update(params) statistic.compute_theory_vector(tools) @@ -619,13 +623,15 @@ def test_from_measurement_compute_theory_vector_window( assert isinstance(two_point_with_window, TwoPoint) assert two_point_with_window.ready - req_params = two_point_with_window.required_parameters() + tools = ModelingTools() + req_params = ( + two_point_with_window.required_parameters() + tools.required_parameters() + ) default_values = req_params.get_default_values() params = ParamsMap(default_values) - tools = ModelingTools() tools.update(params) - tools.prepare(pyccl.CosmologyVanillaLCDM()) + tools.prepare() two_point_with_window.update(params) prediction = two_point_with_window.compute_theory_vector(tools) @@ -643,14 +649,15 @@ def test_from_measurement_compute_theory_vector_window_check( assert isinstance(two_point_without_window, TwoPoint) assert two_point_without_window.ready - req_params = two_point_with_window.required_parameters() + tools = ModelingTools() + req_params = ( + two_point_with_window.required_parameters() + tools.required_parameters() + ) default_values = req_params.get_default_values() params = ParamsMap(default_values) - tools = ModelingTools() tools.update(params) - tools.prepare(pyccl.CosmologyVanillaLCDM()) - + tools.prepare() two_point_with_window.update(params) two_point_without_window.update(params) diff --git a/tests/likelihood/lkdir/lk_derived_parameter.py b/tests/likelihood/lkdir/lk_derived_parameter.py index 9b490f89..78d90e20 100644 --- a/tests/likelihood/lkdir/lk_derived_parameter.py +++ b/tests/likelihood/lkdir/lk_derived_parameter.py @@ -4,9 +4,14 @@ """ from firecrown.likelihood.likelihood import NamedParameters +from firecrown.modeling_tools import ModelingTools +from firecrown.ccl_factory import CCLFactory, PoweSpecAmplitudeParameter from . 
import lkmodule def build_likelihood(_: NamedParameters): """Return a DerivedParameterLikelihood object.""" - return lkmodule.derived_parameter_likelihood() + tools = ModelingTools( + ccl_factory=CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.SIGMA8) + ) + return lkmodule.derived_parameter_likelihood(), tools diff --git a/tests/likelihood/lkdir/lk_needing_param.py b/tests/likelihood/lkdir/lk_needing_param.py index e228cfdd..28e9f745 100644 --- a/tests/likelihood/lkdir/lk_needing_param.py +++ b/tests/likelihood/lkdir/lk_needing_param.py @@ -4,9 +4,15 @@ """ from firecrown.likelihood.likelihood import NamedParameters +from firecrown.modeling_tools import ModelingTools +from firecrown.ccl_factory import CCLFactory, PoweSpecAmplitudeParameter from . import lkmodule def build_likelihood(params: NamedParameters): """Return a ParameterizedLikelihood object.""" - return lkmodule.parameterized_likelihood(params) + tools = ModelingTools( + ccl_factory=CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.SIGMA8) + ) + + return lkmodule.parameterized_likelihood(params), tools diff --git a/tests/likelihood/lkdir/lk_sampler_parameter.py b/tests/likelihood/lkdir/lk_sampler_parameter.py index 7cf9a2d3..f89bba05 100644 --- a/tests/likelihood/lkdir/lk_sampler_parameter.py +++ b/tests/likelihood/lkdir/lk_sampler_parameter.py @@ -5,9 +5,14 @@ """ from firecrown.likelihood.likelihood import NamedParameters +from firecrown.modeling_tools import ModelingTools +from firecrown.ccl_factory import CCLFactory, PoweSpecAmplitudeParameter from . import lkmodule def build_likelihood(params: NamedParameters): """Return a SamplerParameterLikelihood object.""" - return lkmodule.sampler_parameter_likelihood(params) + tools = ModelingTools( + ccl_factory=CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.SIGMA8) + ) + return lkmodule.sampler_parameter_likelihood(params), tools diff --git a/tests/likelihood/lkdir/lkscript.py b/tests/likelihood/lkdir/lkscript.py index 011abc77..7478afa3 100644 --- a/tests/likelihood/lkdir/lkscript.py +++ b/tests/likelihood/lkdir/lkscript.py @@ -3,11 +3,14 @@ """ from firecrown.modeling_tools import ModelingTools +from firecrown.ccl_factory import CCLFactory, PoweSpecAmplitudeParameter from . import lkmodule def build_likelihood(_): """Return an EmptyLikelihood object.""" - tools = ModelingTools() + tools = ModelingTools( + ccl_factory=CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.SIGMA8) + ) tools.test_attribute = "test" # type: ignore[unused-ignore] return (lkmodule.empty_likelihood(), tools) diff --git a/tests/likelihood/test_factories.py b/tests/likelihood/test_factories.py new file mode 100644 index 00000000..45c15da1 --- /dev/null +++ b/tests/likelihood/test_factories.py @@ -0,0 +1,315 @@ +"""Tests for the module firecrown.likelihood.factories. 
+""" + +import re +from pathlib import Path +import pytest + +import sacc + +from firecrown.likelihood.factories import ( + build_two_point_likelihood, + DataSourceSacc, + TwoPointCorrelationSpace, + TwoPointExperiment, + TwoPointFactory, +) +from firecrown.likelihood.weak_lensing import WeakLensingFactory +from firecrown.likelihood.number_counts import NumberCountsFactory +from firecrown.likelihood.likelihood import Likelihood, NamedParameters +from firecrown.modeling_tools import ModelingTools + + +def test_two_point_factory_dict() -> None: + two_point_factory_dict = { + "correlation_space": TwoPointCorrelationSpace.HARMONIC, + "weak_lensing_factory": {"per_bin_systematics": [], "global_systematics": []}, + "number_counts_factory": {"per_bin_systematics": [], "global_systematics": []}, + } + two_point_factory = TwoPointFactory.model_validate(two_point_factory_dict) + assert isinstance(two_point_factory, TwoPointFactory) + assert isinstance(two_point_factory.weak_lensing_factory, WeakLensingFactory) + assert isinstance(two_point_factory.number_counts_factory, NumberCountsFactory) + assert two_point_factory.correlation_space == TwoPointCorrelationSpace.HARMONIC + assert two_point_factory.weak_lensing_factory.per_bin_systematics == [] + assert two_point_factory.weak_lensing_factory.global_systematics == [] + assert two_point_factory.number_counts_factory.per_bin_systematics == [] + assert two_point_factory.number_counts_factory.global_systematics == [] + + +def test_two_point_factor_direct() -> None: + two_point_factory = TwoPointFactory( + correlation_space=TwoPointCorrelationSpace.HARMONIC, + weak_lensing_factory=WeakLensingFactory( + per_bin_systematics=[], global_systematics=[] + ), + number_counts_factory=NumberCountsFactory( + per_bin_systematics=[], global_systematics=[] + ), + ) + + assert two_point_factory.correlation_space == TwoPointCorrelationSpace.HARMONIC + assert two_point_factory.weak_lensing_factory.per_bin_systematics == [] + assert two_point_factory.weak_lensing_factory.global_systematics == [] + assert two_point_factory.number_counts_factory.per_bin_systematics == [] + assert two_point_factory.number_counts_factory.global_systematics == [] + + +def test_two_point_factor_invalid_correlation_space_type() -> None: + two_point_factory_dict = { + "correlation_space": 1.2, + "weak_lensing_factory": {"per_bin_systematics": [], "global_systematics": []}, + "number_counts_factory": {"per_bin_systematics": [], "global_systematics": []}, + } + with pytest.raises( + ValueError, + match=re.compile( + ".*validation error for TwoPointFactory.*correlation_space.*", re.DOTALL + ), + ): + TwoPointFactory.model_validate(two_point_factory_dict) + + +def test_two_point_factor_invalid_correlation_space_option() -> None: + two_point_factory_dict = { + "correlation_space": "invalid", + "weak_lensing_factory": {"per_bin_systematics": [], "global_systematics": []}, + "number_counts_factory": {"per_bin_systematics": [], "global_systematics": []}, + } + with pytest.raises( + ValueError, + match=re.compile( + ".*validation error for TwoPointFactory.*correlation_space.*", re.DOTALL + ), + ): + TwoPointFactory.model_validate(two_point_factory_dict) + + +def test_data_source_sacc_dict() -> None: + data_source_sacc_dict = {"sacc_data_file": "tests/bug_398.sacc.gz"} + data_source_sacc = DataSourceSacc.model_validate(data_source_sacc_dict) + assert isinstance(data_source_sacc, DataSourceSacc) + assert data_source_sacc.sacc_data_file == "tests/bug_398.sacc.gz" + + +def test_data_source_sacc_direct() -> 
None: + data_source_sacc = DataSourceSacc(sacc_data_file="tests/bug_398.sacc.gz") + assert data_source_sacc.sacc_data_file == "tests/bug_398.sacc.gz" + + +def test_data_source_sacc_invalid_file() -> None: + data_source_sacc_dict = {"sacc_data_file": "tests/file_not_found.sacc.gz"} + with pytest.raises( + FileNotFoundError, match=".*File tests/file_not_found.sacc.gz does not exist.*" + ): + DataSourceSacc.model_validate(data_source_sacc_dict) + + +def test_data_source_sacc_get_sacc_data() -> None: + data_source_sacc_dict = {"sacc_data_file": "tests/bug_398.sacc.gz"} + data_source_sacc = DataSourceSacc.model_validate(data_source_sacc_dict) + assert isinstance(data_source_sacc, DataSourceSacc) + assert data_source_sacc.sacc_data_file == "tests/bug_398.sacc.gz" + sacc_data = data_source_sacc.get_sacc_data() + assert sacc_data is not None + assert isinstance(sacc_data, sacc.Sacc) + + +def test_two_point_experiment_dict() -> None: + two_point_experiment_dict = { + "two_point_factory": { + "correlation_space": TwoPointCorrelationSpace.HARMONIC, + "weak_lensing_factory": { + "per_bin_systematics": [], + "global_systematics": [], + }, + "number_counts_factory": { + "per_bin_systematics": [], + "global_systematics": [], + }, + }, + "data_source": {"sacc_data_file": "tests/bug_398.sacc.gz"}, + } + two_point_experiment = TwoPointExperiment.model_validate(two_point_experiment_dict) + assert isinstance(two_point_experiment, TwoPointExperiment) + assert isinstance(two_point_experiment.two_point_factory, TwoPointFactory) + assert isinstance(two_point_experiment.data_source, DataSourceSacc) + assert ( + two_point_experiment.two_point_factory.correlation_space + == TwoPointCorrelationSpace.HARMONIC + ) + assert ( + two_point_experiment.two_point_factory.weak_lensing_factory.per_bin_systematics + == [] + ) + assert ( + two_point_experiment.two_point_factory.weak_lensing_factory.global_systematics + == [] + ) + assert ( + two_point_experiment.two_point_factory.number_counts_factory.per_bin_systematics + == [] + ) + assert ( + two_point_experiment.two_point_factory.number_counts_factory.global_systematics + == [] + ) + assert two_point_experiment.data_source.sacc_data_file == "tests/bug_398.sacc.gz" + + +def test_two_point_experiment_direct() -> None: + two_point_experiment = TwoPointExperiment( + two_point_factory=TwoPointFactory( + correlation_space=TwoPointCorrelationSpace.HARMONIC, + weak_lensing_factory=WeakLensingFactory( + per_bin_systematics=[], global_systematics=[] + ), + number_counts_factory=NumberCountsFactory( + per_bin_systematics=[], global_systematics=[] + ), + ), + data_source=DataSourceSacc(sacc_data_file="tests/bug_398.sacc.gz"), + ) + assert isinstance(two_point_experiment, TwoPointExperiment) + assert isinstance(two_point_experiment.two_point_factory, TwoPointFactory) + assert isinstance(two_point_experiment.data_source, DataSourceSacc) + assert ( + two_point_experiment.two_point_factory.correlation_space + == TwoPointCorrelationSpace.HARMONIC + ) + assert ( + two_point_experiment.two_point_factory.weak_lensing_factory.per_bin_systematics + == [] + ) + assert ( + two_point_experiment.two_point_factory.weak_lensing_factory.global_systematics + == [] + ) + assert ( + two_point_experiment.two_point_factory.number_counts_factory.per_bin_systematics + == [] + ) + assert ( + two_point_experiment.two_point_factory.number_counts_factory.global_systematics + == [] + ) + assert two_point_experiment.data_source.sacc_data_file == "tests/bug_398.sacc.gz" + + +def 
test_build_two_point_likelihood_real(tmp_path: Path) -> None: + tmp_experiment_file = tmp_path / "experiment.yaml" + tmp_experiment_file.write_text( + """ +two_point_factory: + correlation_space: real + weak_lensing_factory: + per_bin_systematics: [] + global_systematics: [] + number_counts_factory: + per_bin_systematics: [] + global_systematics: [] +data_source: + sacc_data_file: examples/des_y1_3x2pt/des_y1_3x2pt_sacc_data.fits +""" + ) + + build_parameters = NamedParameters({"likelihood_config": str(tmp_experiment_file)}) + likelihood, tools = build_two_point_likelihood(build_parameters) + assert isinstance(likelihood, Likelihood) + assert isinstance(tools, ModelingTools) + + +def test_build_two_point_likelihood_harmonic(tmp_path: Path) -> None: + tmp_experiment_file = tmp_path / "experiment.yaml" + tmp_experiment_file.write_text( + """ +two_point_factory: + correlation_space: harmonic + weak_lensing_factory: + per_bin_systematics: [] + global_systematics: [] + number_counts_factory: + per_bin_systematics: [] + global_systematics: [] +data_source: + sacc_data_file: tests/bug_398.sacc.gz +""" + ) + + build_parameters = NamedParameters({"likelihood_config": str(tmp_experiment_file)}) + likelihood, tools = build_two_point_likelihood(build_parameters) + assert isinstance(likelihood, Likelihood) + assert isinstance(tools, ModelingTools) + + +def test_build_two_point_likelihood_real_no_real_data(tmp_path: Path) -> None: + tmp_experiment_file = tmp_path / "experiment.yaml" + tmp_experiment_file.write_text( + """ +two_point_factory: + correlation_space: real + weak_lensing_factory: + per_bin_systematics: [] + global_systematics: [] + number_counts_factory: + per_bin_systematics: [] + global_systematics: [] +data_source: + sacc_data_file: tests/bug_398.sacc.gz +""" + ) + + build_parameters = NamedParameters({"likelihood_config": str(tmp_experiment_file)}) + with pytest.raises( + ValueError, + match=re.compile( + "No two-point measurements in real space found in the SACC file.", re.DOTALL + ), + ): + _ = build_two_point_likelihood(build_parameters) + + +def test_build_two_point_likelihood_harmonic_no_harmonic_data(tmp_path: Path) -> None: + tmp_experiment_file = tmp_path / "experiment.yaml" + tmp_experiment_file.write_text( + """ +two_point_factory: + correlation_space: harmonic + weak_lensing_factory: + per_bin_systematics: [] + global_systematics: [] + number_counts_factory: + per_bin_systematics: [] + global_systematics: [] +data_source: + sacc_data_file: examples/des_y1_3x2pt/des_y1_3x2pt_sacc_data.fits +""" + ) + + build_parameters = NamedParameters({"likelihood_config": str(tmp_experiment_file)}) + with pytest.raises( + ValueError, + match=re.compile( + "No two-point measurements in harmonic space found in the SACC file.", + re.DOTALL, + ), + ): + _ = build_two_point_likelihood(build_parameters) + + +def test_build_two_point_likelihood_empty_likelihood_config(tmp_path: Path) -> None: + tmp_experiment_file = tmp_path / "experiment.yaml" + tmp_experiment_file.write_text("") + + build_parameters = NamedParameters({"likelihood_config": str(tmp_experiment_file)}) + with pytest.raises(ValueError, match="No likelihood config found."): + _ = build_two_point_likelihood(build_parameters) + + +def test_build_two_point_likelihood_invalid_likelihood_config(tmp_path: Path) -> None: + tmp_experiment_file = tmp_path / "experiment.yaml" + tmp_experiment_file.write_text("I'm not a valid YAML file.") + + build_parameters = NamedParameters({"likelihood_config": str(tmp_experiment_file)}) + with 
pytest.raises(ValueError, match=".*validation error for TwoPointExperiment.*"): + _ = build_two_point_likelihood(build_parameters) diff --git a/tests/likelihood/test_likelihood.py b/tests/likelihood/test_likelihood.py index f811d606..92e10291 100644 --- a/tests/likelihood/test_likelihood.py +++ b/tests/likelihood/test_likelihood.py @@ -124,7 +124,9 @@ def test_load_likelihood_from_module(): def test_load_likelihood_from_module_not_in_path(): with pytest.raises( ValueError, - match="Unrecognized Firecrown initialization module lkscript_not_in_path.", + match="Unrecognized Firecrown initialization module 'lkscript_not_in_path'. " + "The module must be either a module_name or a module_name.func where func " + "is the factory function.", ): _ = load_likelihood_from_module("lkscript_not_in_path", NamedParameters()) @@ -136,3 +138,16 @@ def test_load_likelihood_not_in_path(): "lkscript_not_in_path.", ): _ = load_likelihood("lkscript_not_in_path", NamedParameters()) + + +def test_load_likelihood_from_module_function_missing_likelihood_config(): + dir_path = os.path.dirname(os.path.realpath(__file__)) + module_path = os.path.join(dir_path, "lkdir") + + sys.path.append(module_path) + + with pytest.raises(KeyError, match="likelihood_config"): + _ = load_likelihood_from_module( + "firecrown.likelihood.factories.build_two_point_likelihood", + NamedParameters(), + ) diff --git a/tests/test_bug398.py b/tests/test_bug398.py index bf99b6fb..266abbcc 100644 --- a/tests/test_bug398.py +++ b/tests/test_bug398.py @@ -2,16 +2,16 @@ import os import sacc -import pyccl from numpy.testing import assert_allclose +from firecrown.updatable import get_default_params_map import firecrown.likelihood.weak_lensing as wl from firecrown.likelihood.two_point import TwoPoint from firecrown.likelihood.gaussian import ConstGaussian from firecrown.modeling_tools import ModelingTools from firecrown.likelihood.likelihood import Likelihood, NamedParameters -from firecrown.parameters import ParamsMap +from firecrown.ccl_factory import CCLFactory, PoweSpecAmplitudeParameter def build_likelihood( @@ -42,7 +42,9 @@ def build_likelihood( sacc_data_type="galaxy_density_cl", ) - modeling_tools = ModelingTools() + modeling_tools = ModelingTools( + ccl_factory=CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.SIGMA8) + ) likelihood = ConstGaussian(statistics=[src2_src2, lens0_lens0, lens0_src2]) likelihood.read(sacc_data) @@ -119,7 +121,7 @@ def build_likelihood( ] -def test_broken_window_function(): +def test_broken_window_function() -> None: likelihood, modeling_tools = build_likelihood( build_parameters=NamedParameters({"sacc_data": SACC_FILE}) ) @@ -127,13 +129,13 @@ def test_broken_window_function(): assert modeling_tools is not None -def test_eval_cl_window_src2_src2(): - tools = ModelingTools() - cosmo = pyccl.CosmologyVanillaLCDM() - params = ParamsMap() - +def test_eval_cl_window_src2_src2() -> None: + tools = ModelingTools( + ccl_factory=CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.SIGMA8) + ) + params = get_default_params_map(tools) tools.update(params) - tools.prepare(cosmo) + tools.prepare() sacc_data = sacc.Sacc.load_fits(SACC_FILE) src2 = wl.WeakLensing(sacc_tracer="src2") @@ -151,13 +153,13 @@ def test_eval_cl_window_src2_src2(): assert_allclose(theory_vector, SRC2_SRC2_CL_VANILLA_LCDM, rtol=1e-8) -def test_eval_cl_window_lens0_src2(): - tools = ModelingTools() - cosmo = pyccl.CosmologyVanillaLCDM() - params = ParamsMap() - +def test_eval_cl_window_lens0_src2() -> None: + tools = ModelingTools( + 
ccl_factory=CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.SIGMA8) + ) + params = get_default_params_map(tools) tools.update(params) - tools.prepare(cosmo) + tools.prepare() sacc_data = sacc.Sacc.load_fits(SACC_FILE) src2 = wl.WeakLensing(sacc_tracer="src2") @@ -176,13 +178,13 @@ def test_eval_cl_window_lens0_src2(): assert_allclose(theory_vector, LENS0_SRC2_CL_VANILLA_LCDM, rtol=1e-8) -def test_eval_cl_window_lens0_lens0(): - tools = ModelingTools() - cosmo = pyccl.CosmologyVanillaLCDM() - params = ParamsMap() - +def test_eval_cl_window_lens0_lens0() -> None: + tools = ModelingTools( + ccl_factory=CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.SIGMA8) + ) + params = get_default_params_map(tools) tools.update(params) - tools.prepare(cosmo) + tools.prepare() sacc_data = sacc.Sacc.load_fits(SACC_FILE) lens0 = wl.WeakLensing(sacc_tracer="lens0") @@ -200,13 +202,13 @@ def test_eval_cl_window_lens0_lens0(): assert_allclose(theory_vector, LENS0_LENS0_CL_VANILLA_LCDM, rtol=1e-8) -def test_compute_likelihood_src0_src0(): - tools = ModelingTools() - cosmo = pyccl.CosmologyVanillaLCDM() - params = ParamsMap() - +def test_compute_likelihood_src0_src0() -> None: + tools = ModelingTools( + ccl_factory=CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.SIGMA8) + ) + params = get_default_params_map(tools) tools.update(params) - tools.prepare(cosmo) + tools.prepare() sacc_data = sacc.Sacc.load_fits(SACC_FILE) src0 = wl.WeakLensing(sacc_tracer="src0") diff --git a/tests/test_bug_413.py b/tests/test_bug_413.py index 1091a7fc..16bdfcb6 100644 --- a/tests/test_bug_413.py +++ b/tests/test_bug_413.py @@ -4,12 +4,12 @@ import pyccl import sacc +from firecrown.updatable import get_default_params_map from firecrown.likelihood.number_counts import ( NumberCounts, PTNonLinearBiasSystematic, ) from firecrown.likelihood.two_point import TwoPoint -from firecrown.parameters import ParamsMap from firecrown.modeling_tools import ModelingTools @@ -42,17 +42,6 @@ def make_twopoint_with_optional_systematics( ) statistic.read(sacc_data) # Note the two sources have identical parameters. 
- param_map = ParamsMap( - { - "lens0_bias": 1.1, - "lens0_b_2": 1.05, - "lens0_b_s": 0.99, - "lens1_bias": 1.1, - "lens1_b_2": 1.05, - "lens1_b_s": 0.99, - } - ) - statistic.update(param_map) tools = ModelingTools( pt_calculator=pyccl.nl_pt.EulerianPTCalculator( with_NC=True, @@ -62,8 +51,21 @@ def make_twopoint_with_optional_systematics( nk_per_decade=20, ) ) - tools.update(param_map) - tools.prepare(pyccl.CosmologyVanillaLCDM()) + params = get_default_params_map(tools) + params.update( + { + "lens0_bias": 1.1, + "lens0_b_2": 1.05, + "lens0_b_s": 0.99, + "lens1_bias": 1.1, + "lens1_b_2": 1.05, + "lens1_b_s": 0.99, + } + ) + + tools.update(params) + tools.prepare() + statistic.update(params) _ = a.get_tracers(tools) _ = b.get_tracers(tools) assert a.tracers[0].has_pt == first_source_has_systematic diff --git a/tests/test_ccl_factory.py b/tests/test_ccl_factory.py new file mode 100644 index 00000000..3e12d614 --- /dev/null +++ b/tests/test_ccl_factory.py @@ -0,0 +1,367 @@ +"""Test the CCLFactory object.""" + +import numpy as np +import pytest +import pyccl +from pyccl.neutrinos import NeutrinoMassSplits + +from firecrown.ccl_factory import ( + CCLFactory, + PoweSpecAmplitudeParameter, + CCLCalculatorArgs, +) +from firecrown.updatable import get_default_params_map +from firecrown.parameters import ParamsMap +from firecrown.utils import base_model_from_yaml, base_model_to_yaml + + +@pytest.fixture(name="amplitude_parameter", params=list(PoweSpecAmplitudeParameter)) +def fixture_amplitude_parameter(request): + return request.param + + +@pytest.fixture(name="neutrino_mass_splits", params=list(NeutrinoMassSplits)) +def fixture_neutrino_mass_splits(request): + return request.param + + +@pytest.fixture( + name="require_nonlinear_pk", + params=[True, False], + ids=["require_nonlinear_pk", "no_require_nonlinear_pk"], +) +def fixture_require_nonlinear_pk(request): + return request.param + + +Z_ARRAY = np.linspace(0.0, 5.0, 100) +A_ARRAY = 1.0 / (1.0 + np.flip(Z_ARRAY)) +K_ARRAY = np.geomspace(1.0e-5, 10.0, 100) +A_GRID, K_GRID = np.meshgrid(A_ARRAY, K_ARRAY, indexing="ij") + +CHI_ARRAY = np.linspace(100.0, 0.0, 100) +H_OVER_H0_ARRAY = np.linspace(1.0, 100.0, 100) +# Simple power spectrum model for testing +PK_ARRAY = ( + 2.0e-9 + * (K_GRID / 0.05) ** (0.96) + * (np.log(1 + 2.34 * K_GRID / 0.30) / (2.34 * K_GRID / 0.30)) + * (1 / (1 + (K_GRID / (0.1 * 0.05)) ** (3 * 0.96)) * A_GRID) +) + +BACKGROUND = { + "background": {"a": A_ARRAY, "chi": CHI_ARRAY, "h_over_h0": H_OVER_H0_ARRAY} +} + +PK_LINEAR = { + "pk_linear": {"k": K_ARRAY, "a": A_ARRAY, "delta_matter:delta_matter": PK_ARRAY} +} + +PK_NONLIN = { + "pk_nonlin": { + "k": np.linspace(0.1, 1.0, 100), + "a": np.linspace(0.1, 1.0, 100), + "delta_matter:delta_matter": PK_ARRAY, + } +} + + +@pytest.fixture( + name="calculator_args", + params=[BACKGROUND, BACKGROUND | PK_LINEAR, BACKGROUND | PK_LINEAR | PK_NONLIN], + ids=["background", "background_linear_pk", "background_linear_pk_nonlinear_pk"], +) +def fixture_calculator_args(request) -> CCLCalculatorArgs: + return request.param + + +def test_ccl_factory_simple( + amplitude_parameter: PoweSpecAmplitudeParameter, + neutrino_mass_splits: NeutrinoMassSplits, + require_nonlinear_pk: bool, +) -> None: + ccl_factory = CCLFactory( + amplitude_parameter=amplitude_parameter, + mass_split=neutrino_mass_splits, + require_nonlinear_pk=require_nonlinear_pk, + ) + + assert ccl_factory is not None + assert ccl_factory.amplitude_parameter == amplitude_parameter + assert ccl_factory.mass_split == neutrino_mass_splits 
+    assert ccl_factory.require_nonlinear_pk == require_nonlinear_pk
+
+    default_params = get_default_params_map(ccl_factory)
+
+    ccl_factory.update(default_params)
+
+    cosmo = ccl_factory.create()
+
+    assert cosmo is not None
+    assert isinstance(cosmo, pyccl.Cosmology)
+
+
+def test_ccl_factory_ccl_args(
+    amplitude_parameter: PoweSpecAmplitudeParameter,
+    neutrino_mass_splits: NeutrinoMassSplits,
+    require_nonlinear_pk: bool,
+    calculator_args: CCLCalculatorArgs,
+) -> None:
+    ccl_factory = CCLFactory(
+        amplitude_parameter=amplitude_parameter,
+        mass_split=neutrino_mass_splits,
+        require_nonlinear_pk=require_nonlinear_pk,
+    )
+
+    if require_nonlinear_pk and "pk_linear" not in calculator_args:
+        pytest.skip(
+            "Nonlinear power spectrum requested but "
+            "linear power spectrum not provided."
+        )
+
+    assert ccl_factory is not None
+    assert ccl_factory.amplitude_parameter == amplitude_parameter
+    assert ccl_factory.mass_split == neutrino_mass_splits
+    assert ccl_factory.require_nonlinear_pk == require_nonlinear_pk
+
+    default_params = get_default_params_map(ccl_factory)
+
+    ccl_factory.update(default_params)
+
+    cosmo = ccl_factory.create(calculator_args)
+
+    assert cosmo is not None
+    assert isinstance(cosmo, pyccl.Cosmology)
+
+
+def test_ccl_factory_update() -> None:
+    ccl_factory = CCLFactory()
+
+    assert ccl_factory is not None
+
+    default_params = get_default_params_map(ccl_factory)
+
+    ccl_factory.update(default_params)
+
+    cosmo = ccl_factory.create()
+
+    assert cosmo is not None
+    assert isinstance(cosmo, pyccl.Cosmology)
+
+    new_params = ParamsMap(default_params.copy())
+    new_params["Omega_c"] = 0.1
+
+    ccl_factory.reset()
+    ccl_factory.update(new_params)
+
+    cosmo = ccl_factory.create()
+
+    assert cosmo is not None
+    assert isinstance(cosmo, pyccl.Cosmology)
+
+
+def test_ccl_factory_amplitude_parameter(
+    amplitude_parameter: PoweSpecAmplitudeParameter,
+) -> None:
+    ccl_factory = CCLFactory(amplitude_parameter=amplitude_parameter)
+
+    assert ccl_factory is not None
+
+    default_params = get_default_params_map(ccl_factory)
+
+    ccl_factory.update(default_params)
+
+    cosmo = ccl_factory.create()
+
+    assert cosmo is not None
+    assert isinstance(cosmo, pyccl.Cosmology)
+
+
+def test_ccl_factory_neutrino_mass_splits(
+    neutrino_mass_splits: NeutrinoMassSplits,
+) -> None:
+    ccl_factory = CCLFactory(mass_split=neutrino_mass_splits)
+
+    assert ccl_factory is not None
+
+    default_params = get_default_params_map(ccl_factory)
+
+    ccl_factory.update(default_params)
+
+    cosmo = ccl_factory.create()
+
+    assert cosmo is not None
+    assert isinstance(cosmo, pyccl.Cosmology)
+
+
+def test_ccl_factory_get() -> None:
+    ccl_factory = CCLFactory()
+
+    assert ccl_factory is not None
+
+    default_params = get_default_params_map(ccl_factory)
+
+    ccl_factory.update(default_params)
+
+    cosmo = ccl_factory.create()
+
+    assert cosmo is not None
+    assert isinstance(cosmo, pyccl.Cosmology)
+
+    cosmo2 = ccl_factory.get()
+
+    assert cosmo2 is not None
+    assert isinstance(cosmo2, pyccl.Cosmology)
+
+    assert cosmo is cosmo2
+
+
+def test_ccl_factory_get_not_created() -> None:
+    ccl_factory = CCLFactory()
+
+    with pytest.raises(ValueError, match="CCLFactory object has not been created yet."):
+        ccl_factory.get()
+
+
+def test_ccl_factory_create_twice() -> None:
+    ccl_factory = CCLFactory()
+
+    assert ccl_factory is not None
+
+    default_params = get_default_params_map(ccl_factory)
+
+    ccl_factory.update(default_params)
+
+    cosmo = ccl_factory.create()
+
+    assert cosmo is not None
+    assert isinstance(cosmo,
pyccl.Cosmology) + + with pytest.raises(ValueError, match="CCLFactory object has already been created."): + ccl_factory.create() + + +def test_ccl_factory_create_not_updated() -> None: + ccl_factory = CCLFactory() + + assert ccl_factory is not None + + with pytest.raises(ValueError, match="Parameters have not been updated yet."): + ccl_factory.create() + + +def test_ccl_factory_tofrom_yaml() -> None: + ccl_factory = CCLFactory() + + assert ccl_factory is not None + + default_params = get_default_params_map(ccl_factory) + + ccl_factory.update(default_params) + + cosmo = ccl_factory.create() + + assert cosmo is not None + assert isinstance(cosmo, pyccl.Cosmology) + + yaml_str = base_model_to_yaml(ccl_factory) + + ccl_factory2 = base_model_from_yaml(CCLFactory, yaml_str) + + assert ccl_factory2 is not None + ccl_factory2.update(default_params) + + cosmo2 = ccl_factory2.create() + + assert cosmo2 is not None + assert isinstance(cosmo2, pyccl.Cosmology) + + assert cosmo == cosmo2 + + +def test_ccl_factory_tofrom_yaml_all_options( + amplitude_parameter: PoweSpecAmplitudeParameter, + neutrino_mass_splits: NeutrinoMassSplits, + require_nonlinear_pk: bool, +) -> None: + ccl_factory = CCLFactory( + amplitude_parameter=amplitude_parameter, + mass_split=neutrino_mass_splits, + require_nonlinear_pk=require_nonlinear_pk, + ) + + assert ccl_factory is not None + assert ccl_factory.amplitude_parameter == amplitude_parameter + assert ccl_factory.mass_split == neutrino_mass_splits + assert ccl_factory.require_nonlinear_pk == require_nonlinear_pk + + default_params = get_default_params_map(ccl_factory) + + ccl_factory.update(default_params) + + cosmo = ccl_factory.create() + + assert cosmo is not None + assert isinstance(cosmo, pyccl.Cosmology) + + yaml_str = base_model_to_yaml(ccl_factory) + + ccl_factory2 = base_model_from_yaml(CCLFactory, yaml_str) + + assert ccl_factory2 is not None + assert ccl_factory2.amplitude_parameter == amplitude_parameter + assert ccl_factory2.mass_split == neutrino_mass_splits + assert ccl_factory2.require_nonlinear_pk == require_nonlinear_pk + ccl_factory2.update(default_params) + + cosmo2 = ccl_factory2.create() + + assert cosmo2 is not None + assert isinstance(cosmo2, pyccl.Cosmology) + + assert cosmo == cosmo2 + + +def test_ccl_factory_invalid_amplitude_parameter() -> None: + with pytest.raises( + ValueError, + match=".*Invalid value for PoweSpecAmplitudeParameter: Im not a valid value.*", + ): + CCLFactory(amplitude_parameter="Im not a valid value") + + +def test_ccl_factory_invalid_mass_splits() -> None: + with pytest.raises( + ValueError, + match=".*Invalid value for NeutrinoMassSplits: Im not a valid value.*", + ): + CCLFactory(mass_split="Im not a valid value") + + +def test_ccl_factory_from_dict() -> None: + ccl_factory_dict = { + "amplitude_parameter": PoweSpecAmplitudeParameter.SIGMA8, + "mass_split": NeutrinoMassSplits.EQUAL, + "require_nonlinear_pk": True, + } + + ccl_factory = CCLFactory.model_validate(ccl_factory_dict) + + assert ccl_factory is not None + assert ccl_factory.amplitude_parameter == PoweSpecAmplitudeParameter.SIGMA8 + assert ccl_factory.mass_split == NeutrinoMassSplits.EQUAL + assert ccl_factory.require_nonlinear_pk is True + + +def test_ccl_factory_from_dict_wrong_type() -> None: + ccl_factory_dict = { + "amplitude_parameter": 0.32, + "mass_split": NeutrinoMassSplits.EQUAL, + "require_nonlinear_pk": True, + } + + with pytest.raises( + ValueError, + match=".*Input should be.*", + ): + CCLFactory.model_validate(ccl_factory_dict) diff --git 
a/tests/test_modeling_tools.py b/tests/test_modeling_tools.py index 4ec3959e..d546e8d7 100644 --- a/tests/test_modeling_tools.py +++ b/tests/test_modeling_tools.py @@ -4,8 +4,8 @@ import pytest import pyccl +from firecrown.updatable import get_default_params_map from firecrown.modeling_tools import ModelingTools, PowerspectrumModifier -from firecrown.parameters import ParamsMap @pytest.fixture(name="dummy_powerspectrum") @@ -15,7 +15,7 @@ def make_dummy_powerspectrum() -> pyccl.Pk2D: return pyccl.Pk2D.__new__(pyccl.Pk2D) -def test_default_constructed_state(): +def test_default_constructed_state() -> None: tools = ModelingTools() # Default constructed state is pretty barren... assert tools.ccl_cosmo is None @@ -24,23 +24,23 @@ def test_default_constructed_state(): assert len(tools.powerspectra) == 0 -def test_default_constructed_no_tools(): +def test_default_constructed_no_tools() -> None: tools = ModelingTools() with pytest.raises(RuntimeError): _ = tools.get_pk("nonesuch") -def test_adding_pk_and_getting(dummy_powerspectrum: pyccl.Pk2D): +def test_adding_pk_and_getting(dummy_powerspectrum: pyccl.Pk2D) -> None: tools = ModelingTools() tools.add_pk("silly", dummy_powerspectrum) - cosmo = pyccl.CosmologyVanillaLCDM() - tools.update(ParamsMap()) - tools.prepare(cosmo) + params = get_default_params_map(tools) + tools.update(params) + tools.prepare() assert tools.get_pk("silly") == dummy_powerspectrum -def test_no_adding_pk_twice(dummy_powerspectrum: pyccl.Pk2D): +def test_no_adding_pk_twice(dummy_powerspectrum: pyccl.Pk2D) -> None: tools = ModelingTools() tools.add_pk("silly", dummy_powerspectrum) assert tools.powerspectra["silly"] == dummy_powerspectrum @@ -48,38 +48,38 @@ def test_no_adding_pk_twice(dummy_powerspectrum: pyccl.Pk2D): tools.add_pk("silly", dummy_powerspectrum) -def test_modeling_tool_prepare_without_update(): +def test_modeling_tool_prepare_without_update() -> None: tools = ModelingTools() - cosmo = pyccl.CosmologyVanillaLCDM() with pytest.raises(RuntimeError, match="ModelingTools has not been updated."): - tools.prepare(cosmo) + tools.prepare() -def test_modeling_tool_preparing_twice(): +def test_modeling_tool_preparing_twice() -> None: tools = ModelingTools() - cosmo = pyccl.CosmologyVanillaLCDM() - tools.update(ParamsMap()) - tools.prepare(cosmo) + params = get_default_params_map(tools) + tools.update(params) + tools.prepare() with pytest.raises(RuntimeError, match="ModelingTools has already been prepared"): - tools.prepare(cosmo) + tools.prepare() -def test_modeling_tool_wrongly_setting_ccl_cosmo(): +def test_modeling_tool_wrongly_setting_ccl_cosmo() -> None: tools = ModelingTools() cosmo = pyccl.CosmologyVanillaLCDM() - tools.update(ParamsMap()) + params = get_default_params_map(tools) + tools.update(params) tools.ccl_cosmo = cosmo with pytest.raises(RuntimeError, match="Cosmology has already been set"): - tools.prepare(cosmo) + tools.prepare() -def test_modeling_tools_get_cosmo_without_setting(): +def test_modeling_tools_get_cosmo_without_setting() -> None: tools = ModelingTools() with pytest.raises(RuntimeError, match="Cosmology has not been set"): tools.get_ccl_cosmology() -def test_modeling_tools_get_pt_calculator_without_setting(): +def test_modeling_tools_get_pt_calculator_without_setting() -> None: tools = ModelingTools() with pytest.raises(RuntimeError, match="A PT calculator has not been set"): tools.get_pt_calculator() @@ -98,11 +98,11 @@ def compute_p_of_k_z(self, tools: ModelingTools) -> pyccl.Pk2D: return pyccl.Pk2D.__new__(pyccl.Pk2D) -def 
test_modeling_tools_add_pk_modifiers(): +def test_modeling_tools_add_pk_modifiers() -> None: tools = ModelingTools(pk_modifiers=[DummyPowerspectrumModifier()]) - cosmo = pyccl.CosmologyVanillaLCDM() - tools.update(ParamsMap()) - tools.prepare(cosmo) + params = get_default_params_map(tools) + tools.update(params) + tools.prepare() assert tools.get_pk("dummy:dummy") is not None assert isinstance(tools.get_pk("dummy:dummy"), pyccl.Pk2D) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 5041ee5f..5982bdd5 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -103,7 +103,7 @@ def test_get_params_names_does_not_allow_mutation(): def test_params_map(): - my_params = ParamsMap({"a": 1}) + my_params = ParamsMap({"a": 1.0}) x = my_params.get_from_prefix_param(None, "a") assert x == 1 with pytest.raises(KeyError): @@ -112,6 +112,20 @@ def test_params_map(): _ = my_params.get_from_prefix_param(None, "no_such_name") +def test_params_map_wrong_type(): + with pytest.raises( + TypeError, match="Value for parameter a is not a float or a list of floats.*" + ): + _ = ParamsMap({"a": "not a float or a list of floats"}) + + +def test_params_map_wrong_type_list(): + with pytest.raises( + TypeError, match="Value for parameter a is not a float or a list of floats.*" + ): + _ = ParamsMap({"a": ["not a float or a list of floats"]}) + + def test_parameter_get_full_name_reject_empty_name(): with pytest.raises(ValueError): _ = parameter_get_full_name(None, "") diff --git a/tests/test_pt_systematics.py b/tests/test_pt_systematics.py index 727005c4..8857c84f 100644 --- a/tests/test_pt_systematics.py +++ b/tests/test_pt_systematics.py @@ -12,6 +12,7 @@ import pyccl.nl_pt as pt import sacc +from firecrown.updatable import get_default_params_map import firecrown.likelihood.weak_lensing as wl import firecrown.likelihood.number_counts as nc from firecrown.likelihood.two_point import ( @@ -21,7 +22,7 @@ ) from firecrown.likelihood.gaussian import ConstGaussian from firecrown.modeling_tools import ModelingTools -from firecrown.parameters import ParamsMap +from firecrown.ccl_factory import CCLFactory, PoweSpecAmplitudeParameter @pytest.fixture(name="weak_lensing_source") @@ -85,9 +86,13 @@ def test_pt_systematics(weak_lensing_source, number_counts_source, sacc_data): nk_per_decade=4, cosmo=ccl_cosmo, ) - modeling_tools = ModelingTools(pt_calculator=pt_calculator) - modeling_tools.update(ParamsMap()) - modeling_tools.prepare(ccl_cosmo) + modeling_tools = ModelingTools( + pt_calculator=pt_calculator, + ccl_factory=CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.SIGMA8), + ) + params = get_default_params_map(modeling_tools) + modeling_tools.update(params) + modeling_tools.prepare() # Bare CCL setup a_1 = 1.0 @@ -117,7 +122,7 @@ def test_pt_systematics(weak_lensing_source, number_counts_source, sacc_data): pk_gg = pt_calculator.get_biased_pk2d(tracer1=ptt_g, tracer2=ptt_g) # Set the parameters for our systematics - systematics_params = ParamsMap( + params.update( { "ia_a_1": a_1, "ia_a_2": a_2, @@ -132,7 +137,7 @@ def test_pt_systematics(weak_lensing_source, number_counts_source, sacc_data): ) # Apply the systematics parameters - likelihood.update(systematics_params) + likelihood.update(params) # Make things faster by only using a couple of ells for s in likelihood.statistics: @@ -219,7 +224,7 @@ def test_pt_systematics(weak_lensing_source, number_counts_source, sacc_data): cl_cs_theory = cl_GG + 2 * cl_GI + cl_II cl_gg_theory = cl_gg + 2 * cl_gm + cl_mm - print("IDS: ", id(s0), 
id(s1))
+    # print("IDS: ", id(s0), id(s1))
 
     assert np.allclose(cells_GG, cells_GG_m, atol=0, rtol=1e-127)
 
@@ -282,9 +287,13 @@ def test_pt_mixed_systematics(sacc_data):
         nk_per_decade=4,
         cosmo=ccl_cosmo,
     )
-    modeling_tools = ModelingTools(pt_calculator=pt_calculator)
-    modeling_tools.update(ParamsMap())
-    modeling_tools.prepare(ccl_cosmo)
+    modeling_tools = ModelingTools(
+        pt_calculator=pt_calculator,
+        ccl_factory=CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.SIGMA8),
+    )
+    params = get_default_params_map(modeling_tools)
+    modeling_tools.update(params)
+    modeling_tools.prepare()
 
     # Bare CCL setup
     a_1 = 1.0
@@ -305,7 +314,7 @@ def test_pt_mixed_systematics(sacc_data):
     pk_mi = pt_calculator.get_biased_pk2d(tracer1=ptt_m, tracer2=ptt_i)
 
     # Set the parameters for our systematics
-    systematics_params = ParamsMap(
+    params.update(
         {
             "ia_a_1": a_1,
             "ia_a_2": a_2,
@@ -316,7 +325,7 @@ def test_pt_mixed_systematics(sacc_data):
     )
 
     # Apply the systematics parameters
-    likelihood.update(systematics_params)
+    likelihood.update(params)
 
     # Make things faster by only using a couple of ells
     for s in likelihood.statistics:
diff --git a/tutorial/inferred_zdist.qmd b/tutorial/inferred_zdist.qmd
index 42eb7380..dabb27ca 100644
--- a/tutorial/inferred_zdist.qmd
+++ b/tutorial/inferred_zdist.qmd
@@ -121,16 +121,16 @@ Markdown(f"```yaml\n{default_values_yaml}\n```")
 Now, we must configure the cosmology and prepare the two-point statistics for analysis:
 
 ```{python}
-import pyccl
-
 from firecrown.modeling_tools import ModelingTools
+from firecrown.ccl_factory import CCLFactory
+from firecrown.updatable import get_default_params_map
+from firecrown.parameters import ParamsMap
 
-ccl_cosmo = pyccl.CosmologyVanillaLCDM()
-ccl_cosmo.compute_nonlin_power()
+tools = ModelingTools(ccl_factory=CCLFactory(require_nonlinear_pk=True))
+params = get_default_params_map(tools, two_point_lens0_lens0)
 
-tools = ModelingTools()
 tools.update(params)
-tools.prepare(ccl_cosmo)
+tools.prepare()
 
 two_point_lens0_lens0.update(params)
 theory_vector = two_point_lens0_lens0.compute_theory_vector(tools)
diff --git a/tutorial/two_point_factories.qmd b/tutorial/two_point_factories.qmd
index bfe8eb88..ef3b6c90 100644
--- a/tutorial/two_point_factories.qmd
+++ b/tutorial/two_point_factories.qmd
@@ -190,16 +190,16 @@ Markdown(f"```yaml\n{default_values_yaml}\n```")
 Finally, we can prepare both likelihoods and compare the loglike values.
 
 ```{python}
-import pyccl
-
 from firecrown.modeling_tools import ModelingTools
+from firecrown.ccl_factory import CCLFactory
+from firecrown.updatable import get_default_params_map
+from firecrown.parameters import ParamsMap
 
-ccl_cosmo = pyccl.CosmologyVanillaLCDM()
-ccl_cosmo.compute_nonlin_power()
+tools = ModelingTools(ccl_factory=CCLFactory(require_nonlinear_pk=True))
+params = get_default_params_map(tools, likelihood)
 
-tools = ModelingTools()
 tools.update(params)
-tools.prepare(ccl_cosmo)
+tools.prepare()
 
 likelihood.update(params)
 likelihood_ready.update(params)
diff --git a/tutorial/two_point_generators.qmd b/tutorial/two_point_generators.qmd
index dd38577d..1fa32ba4 100644
--- a/tutorial/two_point_generators.qmd
+++ b/tutorial/two_point_generators.qmd
@@ -148,14 +148,18 @@ all_two_point_functions = tp.TwoPoint.from_metadata(
 ## Setting Up the Two-Point Statistics
 
 Setting up the two-point statistics requires a cosmology and a set of parameters.
-The parameters necessary in our analysis depend on the two-point statistic measurements and the systematics we are considering.
-In the code below, we extract the required parameters for the two-point statistics and create a `ParamsMap` object with the parameters' default values:
+The parameters necessary in our analysis depend on the cosmology, the two-point statistic measurements, and the systematics we are considering.
+The `ModelingTools` class manages the cosmology and the other general computation objects.
+In the code below, we extract the required parameters for the cosmology and the two-point statistics and create a `ParamsMap` object with the parameters' default values:
 
 ```{python}
+from firecrown.modeling_tools import ModelingTools
+from firecrown.ccl_factory import CCLFactory
+from firecrown.updatable import get_default_params
 from firecrown.parameters import ParamsMap
 
-req_params = all_two_point_functions.required_parameters()
-default_values = req_params.get_default_values()
+tools = ModelingTools(ccl_factory=CCLFactory(require_nonlinear_pk=True))
+default_values = get_default_params(tools, all_two_point_functions)
 params = ParamsMap(default_values)
 ```
 
@@ -167,7 +171,7 @@ These default values are:
 # | code-fold: true
 import yaml
 
-default_values_yaml = yaml.dump(default_values, default_flow_style=False)
+default_values_yaml = yaml.dump(dict(params), default_flow_style=False)
 Markdown(f"```yaml\n{default_values_yaml}\n```")
 ```
 
@@ -175,16 +179,8 @@ Markdown(f"```yaml\n{default_values_yaml}\n```")
 Lastly, we must configure the cosmology and prepare the two-point statistics for analysis:
 
 ```{python}
-import pyccl
-
-from firecrown.modeling_tools import ModelingTools
-
-ccl_cosmo = pyccl.CosmologyVanillaLCDM()
-ccl_cosmo.compute_nonlin_power()
-
-tools = ModelingTools()
 tools.update(params)
-tools.prepare(ccl_cosmo)
+tools.prepare()
 
 all_two_point_functions.update(params)
 ```
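
Taken together, the hunks above converge on a single parameter-handling pattern: `ModelingTools` now owns cosmology creation through its `CCLFactory`, and a `ParamsMap` seeded by `get_default_params_map` is pushed through `update` before `prepare`. A minimal sketch of that cycle, using only the firecrown calls exercised in these tests (the SIGMA8 amplitude choice is illustrative, not required):

```python
from firecrown.ccl_factory import CCLFactory, PoweSpecAmplitudeParameter
from firecrown.modeling_tools import ModelingTools
from firecrown.updatable import get_default_params_map

# ModelingTools builds the pyccl.Cosmology itself via its CCLFactory,
# so no hand-constructed pyccl.CosmologyVanillaLCDM() is needed.
tools = ModelingTools(
    ccl_factory=CCLFactory(amplitude_parameter=PoweSpecAmplitudeParameter.SIGMA8)
)

# Seed a ParamsMap with the default value of every parameter the tools
# require, push the values in, then build the cosmology.
params = get_default_params_map(tools)
tools.update(params)
tools.prepare()  # replaces the old tools.prepare(ccl_cosmo)

cosmo = tools.get_ccl_cosmology()  # the pyccl.Cosmology created by the factory
```

Statistics built elsewhere (for example a `TwoPoint`) are then updated with the same `params` before `compute_theory_vector`, as the tests above do.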