From 3bec9be4c287a9f9654b4b36151b295b693bc156 Mon Sep 17 00:00:00 2001 From: Sandro Dias Pinto Vitenti Date: Sat, 15 Oct 2022 22:26:05 -0300 Subject: [PATCH] Derived parameters support (#189) * Initial version of DerivedParameters framework. * Finished first working version of DerivedParameters framework. Included support for derived parameters in the CosmoSIS connector and included its use in the des_y1_3x2pt example. * Improving iterable in DerivedParameterCollection. Fixed minor flake8/mypy/... complaints. Added black --check to the CI. * Fixed pylint issues in tests. * Update documentation of _get_derived_parameters Co-authored-by: Marc Paterno --- .github/workflows/ci.yml | 5 + environment.yml | 1 + examples/des_y1_3x2pt/des_y1_3x2pt.ini | 2 + examples/des_y1_3x2pt/des_y1_3x2pt.py | 2 +- firecrown/connector/cosmosis/likelihood.py | 5 + .../likelihood/gauss_family/gauss_family.py | 40 ++++-- firecrown/likelihood/gauss_family/gaussian.py | 6 +- .../statistic/source/number_counts.py | 40 +++++- .../statistic/source/weak_lensing.py | 29 ++++- .../gauss_family/statistic/supernova.py | 10 +- .../gauss_family/statistic/two_point.py | 9 +- .../likelihood/gauss_family/student_t.py | 6 +- firecrown/likelihood/likelihood.py | 2 +- firecrown/parameters.py | 120 +++++++++++++++++- firecrown/updatable.py | 44 ++++++- tests/likelihood/lkdir/lkmodule.py | 9 +- tests/test_parameters.py | 97 +++++++++++++- tests/test_updatable.py | 25 +++- 18 files changed, 417 insertions(+), 35 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 61bc039ff..a0e2f9f1d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -67,6 +67,11 @@ jobs: shell: bash -l {0} run: python -m pip install cobaya if: steps.cache.outputs.cache-hit != 'true' + - name: Running black check + shell: bash -l {0} + run: | + black --check firecrown + black --check tests - name: Running flake8 shell: bash -l {0} run: flake8 firecrown diff --git a/environment.yml b/environment.yml index a9301f77d..9c9bcf26e 100644 --- a/environment.yml +++ b/environment.yml @@ -4,6 +4,7 @@ channels: dependencies: - cosmosis - cosmosis-build-standard-library + - black - flake8 - mypy - pylint diff --git a/examples/des_y1_3x2pt/des_y1_3x2pt.ini b/examples/des_y1_3x2pt/des_y1_3x2pt.ini index 3b385bf20..1936e7593 100644 --- a/examples/des_y1_3x2pt/des_y1_3x2pt.ini +++ b/examples/des_y1_3x2pt/des_y1_3x2pt.ini @@ -17,6 +17,7 @@ likelihoods = firecrown quiet = T debug = T timing = T +extra_output = TwoPoint/NumberCountsScale_lens0 TwoPoint/NumberCountsScale_lens1 TwoPoint/NumberCountsScale_lens2 TwoPoint/NumberCountsScale_lens3 TwoPoint/NumberCountsScale_lens4 [consistency] file = ${CSL_DIR}/utility/consistency/consistency_interface.py @@ -51,6 +52,7 @@ save_dir = des_y1_3x2pt_output [metropolis] samples = 1000 +nsteps = 1 [emcee] walkers = 64 diff --git a/examples/des_y1_3x2pt/des_y1_3x2pt.py b/examples/des_y1_3x2pt/des_y1_3x2pt.py index ac6fb6ef0..b63965fc0 100644 --- a/examples/des_y1_3x2pt/des_y1_3x2pt.py +++ b/examples/des_y1_3x2pt/des_y1_3x2pt.py @@ -60,7 +60,7 @@ """ The source is created and saved (temporarely in the sources dict). """ - sources[f"lens{i}"] = nc.NumberCounts(sacc_tracer=f"lens{i}", systematics=[pzshift]) + sources[f"lens{i}"] = nc.NumberCounts(sacc_tracer=f"lens{i}", systematics=[pzshift], derived_scale=True) """ Now that we have all sources we can instantiate all the two-point diff --git a/firecrown/connector/cosmosis/likelihood.py b/firecrown/connector/cosmosis/likelihood.py index aaddc9bfc..ef65918e6 100644 --- a/firecrown/connector/cosmosis/likelihood.py +++ b/firecrown/connector/cosmosis/likelihood.py @@ -75,10 +75,15 @@ def execute(self, sample: cosmosis.datablock): self.likelihood.update(firecrown_params) loglike = self.likelihood.compute_loglike(cosmo) + derived_params_collection = self.likelihood.get_derived_parameters() + assert derived_params_collection is not None self.likelihood.reset() sample.put_double(section_names.likelihoods, "firecrown_like", loglike) + for section, name, val in derived_params_collection: + sample.put(section, name, val) + # Save concatenated data vector and inverse covariance to enable support # for the CosmoSIS fisher sampler. sample.put( diff --git a/firecrown/likelihood/gauss_family/gauss_family.py b/firecrown/likelihood/gauss_family/gauss_family.py index dbd11e790..477a8cffc 100644 --- a/firecrown/likelihood/gauss_family/gauss_family.py +++ b/firecrown/likelihood/gauss_family/gauss_family.py @@ -21,7 +21,7 @@ from ..likelihood import Likelihood from ...updatable import UpdatableCollection from .statistic.statistic import Statistic -from ...parameters import ParamsMap, RequiredParameters +from ...parameters import ParamsMap, RequiredParameters, DerivedParameterCollection class GaussFamily(Likelihood): @@ -39,7 +39,7 @@ def __init__(self, statistics: List[Statistic]): self.inv_cov: Optional[np.ndarray] = None def read(self, sacc_data: sacc.Sacc) -> None: - """Read the covariance matrirx for this likelihood from the SACC file.""" + """Read the covariance matrix for this likelihood from the SACC file.""" _sd = sacc_data.copy() inds_list = [] @@ -59,21 +59,24 @@ def read(self, sacc_data: sacc.Sacc) -> None: @final def compute_chisq(self, cosmo: pyccl.Cosmology) -> float: """Calculate and return the chi-squared for the given cosmology.""" - residuals = [] - theory_vector = [] - data_vector = [] + residuals_list: List[np.ndarray] = [] + theory_vector_list: List[np.ndarray] = [] + data_vector_list: List[np.ndarray] = [] for stat in self.statistics: data, theory = stat.compute(cosmo) - residuals.append(np.atleast_1d(data - theory)) - theory_vector.append(np.atleast_1d(theory)) - data_vector.append(np.atleast_1d(data)) + residuals_list.append(np.atleast_1d(data - theory)) + theory_vector_list.append(np.atleast_1d(theory)) + data_vector_list.append(np.atleast_1d(data)) + + residuals = np.concatenate(residuals_list, axis=0) + self.predicted_data_vector: np.ndarray = np.concatenate(theory_vector_list) + self.measured_data_vector: np.ndarray = np.concatenate(data_vector_list) - residuals = np.concatenate(residuals, axis=0) - self.predicted_data_vector = np.concatenate(theory_vector) - self.measured_data_vector = np.concatenate(data_vector) # pylint: disable-next=C0103 x = scipy.linalg.solve_triangular(self.cholesky, residuals, lower=True) - return np.dot(x, x) + chisq = np.dot(x, x) + + return chisq @final def _update(self, params: ParamsMap) -> None: @@ -93,6 +96,15 @@ def _reset(self) -> None: self._reset_gaussian_family() self.statistics.reset() + @final + def _get_derived_parameters(self) -> DerivedParameterCollection: + derived_parameters = ( + self._get_derived_parameters_gaussian_family() + + self.statistics.get_derived_parameters() + ) + + return derived_parameters + @abstractmethod def _update_gaussian_family(self, params: ParamsMap) -> None: """Abstract method to update GaussianFamily state. Must be implemented by all @@ -120,3 +132,7 @@ def required_parameters(self) -> RequiredParameters: @abstractmethod def required_parameters_gaussian_family(self): """Required parameters for GaussFamily subclasses.""" + + @abstractmethod + def _get_derived_parameters_gaussian_family(self) -> DerivedParameterCollection: + """Get derived parameters for GaussFamily subclasses.""" diff --git a/firecrown/likelihood/gauss_family/gaussian.py b/firecrown/likelihood/gauss_family/gaussian.py index d6c660953..5dcf22396 100644 --- a/firecrown/likelihood/gauss_family/gaussian.py +++ b/firecrown/likelihood/gauss_family/gaussian.py @@ -13,7 +13,7 @@ import pyccl from .gauss_family import GaussFamily -from ...parameters import ParamsMap, RequiredParameters +from ...parameters import ParamsMap, RequiredParameters, DerivedParameterCollection class ConstGaussian(GaussFamily): @@ -35,3 +35,7 @@ def _reset_gaussian_family(self): @final def required_parameters_gaussian_family(self): return RequiredParameters([]) + + @final + def _get_derived_parameters_gaussian_family(self) -> DerivedParameterCollection: + return DerivedParameterCollection([]) diff --git a/firecrown/likelihood/gauss_family/statistic/source/number_counts.py b/firecrown/likelihood/gauss_family/statistic/source/number_counts.py index b8bc79794..fa0d4f657 100644 --- a/firecrown/likelihood/gauss_family/statistic/source/number_counts.py +++ b/firecrown/likelihood/gauss_family/statistic/source/number_counts.py @@ -13,7 +13,13 @@ from .source import Source from .source import Systematic -from .....parameters import ParamsMap, RequiredParameters, parameter_get_full_name +from .....parameters import ( + ParamsMap, + RequiredParameters, + parameter_get_full_name, + DerivedParameterScalar, + DerivedParameterCollection, +) from .....updatable import UpdatableCollection __all__ = ["NumberCounts"] @@ -87,6 +93,10 @@ def required_parameters(self) -> RequiredParameters: [parameter_get_full_name(self.sacc_tracer, pn) for pn in self.params_names] ) + @final + def _get_derived_parameters(self) -> DerivedParameterCollection: + return DerivedParameterCollection([]) + def apply( self, cosmo: pyccl.Cosmology, tracer_arg: NumberCountsArgs ) -> NumberCountsArgs: @@ -165,6 +175,10 @@ def required_parameters(self) -> RequiredParameters: [parameter_get_full_name(self.sacc_tracer, pn) for pn in self.params_names] ) + @final + def _get_derived_parameters(self) -> DerivedParameterCollection: + return DerivedParameterCollection([]) + def apply( self, cosmo: pyccl.Cosmology, tracer_arg: NumberCountsArgs ) -> NumberCountsArgs: @@ -227,6 +241,10 @@ def required_parameters(self) -> RequiredParameters: [parameter_get_full_name(self.sacc_tracer, pn) for pn in self.params_names] ) + @final + def _get_derived_parameters(self) -> DerivedParameterCollection: + return DerivedParameterCollection([]) + def apply(self, cosmo: pyccl.Cosmology, tracer_arg: NumberCountsArgs): """Apply a shift to the photo-z distribution of a source.""" @@ -260,6 +278,7 @@ def __init__( sacc_tracer: str, has_rsd: bool = False, has_mag_bias: bool = False, + derived_scale: bool = False, scale: float = 1.0, systematics: Optional[List[NumberCountsSystematic]] = None, ): @@ -268,6 +287,7 @@ def __init__( self.sacc_tracer = sacc_tracer self.has_rsd = has_rsd self.has_mag_bias = has_mag_bias + self.derived_scale = derived_scale self.systematics = UpdatableCollection([]) if systematics: @@ -313,6 +333,24 @@ def required_parameters(self) -> RequiredParameters: ) return rp + self.systematics.required_parameters() + @final + def _get_derived_parameters(self) -> DerivedParameterCollection: + if self.derived_scale: + assert self.current_tracer_args is not None + derived_scale = DerivedParameterScalar( + "TwoPoint", + f"NumberCountsScale_{self.sacc_tracer}", + self.current_tracer_args.scale, + ) + derived_parameters = DerivedParameterCollection([derived_scale]) + else: + derived_parameters = DerivedParameterCollection([]) + derived_parameters = ( + derived_parameters + self.systematics.get_derived_parameters() + ) + + return derived_parameters + def _read(self, sacc_data): """Read the data for this source from the SACC file. diff --git a/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py b/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py index 3848cb42f..6bfad20fd 100644 --- a/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py +++ b/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py @@ -13,7 +13,12 @@ from .source import Source from .source import Systematic -from .....parameters import ParamsMap, RequiredParameters, parameter_get_full_name +from .....parameters import ( + ParamsMap, + RequiredParameters, + parameter_get_full_name, + DerivedParameterCollection, +) from .....updatable import UpdatableCollection __all__ = ["WeakLensing"] @@ -82,6 +87,10 @@ def required_parameters(self) -> RequiredParameters: [parameter_get_full_name(self.sacc_tracer, pn) for pn in self.params_names] ) + @final + def _get_derived_parameters(self) -> DerivedParameterCollection: + return DerivedParameterCollection([]) + def apply(self, cosmo: pyccl.Cosmology, tracer_arg: WeakLensingArgs): """Apply multiplicative shear bias to a source. The `scale_` of the source is multiplied by `(1 + m)`. @@ -159,6 +168,10 @@ def required_parameters(self) -> RequiredParameters: [parameter_get_full_name(self.sacc_tracer, pn) for pn in self.params_names] ) + @final + def _get_derived_parameters(self) -> DerivedParameterCollection: + return DerivedParameterCollection([]) + def apply( self, cosmo: pyccl.Cosmology, tracer_arg: WeakLensingArgs ) -> WeakLensingArgs: @@ -208,6 +221,12 @@ def required_parameters(self) -> RequiredParameters: [parameter_get_full_name(self.sacc_tracer, pn) for pn in self.params_names] ) + @final + def _get_derived_parameters(self) -> DerivedParameterCollection: + derived_parameters = DerivedParameterCollection([]) + + return derived_parameters + def apply(self, cosmo: pyccl.Cosmology, tracer_arg: WeakLensingArgs): """Apply a shift to the photo-z distribution of a source.""" @@ -263,6 +282,14 @@ def _reset_source(self) -> None: def required_parameters(self) -> RequiredParameters: return self.systematics.required_parameters() + @final + def _get_derived_parameters(self) -> DerivedParameterCollection: + derived_parameters = DerivedParameterCollection([]) + derived_parameters = ( + derived_parameters + self.systematics.get_derived_parameters() + ) + return derived_parameters + def _read(self, sacc_data): """Read the data for this source from the SACC file. diff --git a/firecrown/likelihood/gauss_family/statistic/supernova.py b/firecrown/likelihood/gauss_family/statistic/supernova.py index 591989927..c1f186d33 100644 --- a/firecrown/likelihood/gauss_family/statistic/supernova.py +++ b/firecrown/likelihood/gauss_family/statistic/supernova.py @@ -2,7 +2,7 @@ """ from __future__ import annotations -from typing import Tuple, final +from typing import Optional, Tuple, final import numpy as np @@ -10,7 +10,7 @@ import sacc from .statistic import Statistic -from ....parameters import ParamsMap, RequiredParameters +from ....parameters import ParamsMap, RequiredParameters, DerivedParameterCollection class Supernova(Statistic): @@ -23,7 +23,7 @@ def __init__(self, sacc_tracer): self.sacc_tracer = sacc_tracer self.data_vector = None - self.a = None # pylint: disable-msg=invalid-name + self.a: Optional[np.ndarray] = None # pylint: disable-msg=invalid-name self.M = None # pylint: disable-msg=invalid-name def read(self, sacc_data: sacc.Sacc): @@ -55,6 +55,10 @@ def required_parameters(self) -> RequiredParameters: """ return RequiredParameters(["m"]) + @final + def _get_derived_parameters(self) -> DerivedParameterCollection: + return DerivedParameterCollection([]) + def compute(self, cosmo: pyccl.Cosmology) -> Tuple[np.ndarray, np.ndarray]: """Compute a two-point statistic from sources.""" diff --git a/firecrown/likelihood/gauss_family/statistic/two_point.py b/firecrown/likelihood/gauss_family/statistic/two_point.py index 33d6892f6..4e53a20de 100644 --- a/firecrown/likelihood/gauss_family/statistic/two_point.py +++ b/firecrown/likelihood/gauss_family/statistic/two_point.py @@ -14,7 +14,7 @@ from .statistic import Statistic from .source.source import Source, Systematic -from ....parameters import ParamsMap, RequiredParameters +from ....parameters import ParamsMap, RequiredParameters, DerivedParameterCollection # only supported types are here, any thing else will throw # a value error @@ -198,6 +198,13 @@ def _reset(self) -> None: def required_parameters(self) -> RequiredParameters: return self.source0.required_parameters() + self.source1.required_parameters() + @final + def _get_derived_parameters(self) -> DerivedParameterCollection: + derived_parameters = DerivedParameterCollection([]) + derived_parameters = derived_parameters + self.source0.get_derived_parameters() + derived_parameters = derived_parameters + self.source1.get_derived_parameters() + return derived_parameters + def read(self, sacc_data): """Read the data for this statistic from the SACC file. diff --git a/firecrown/likelihood/gauss_family/student_t.py b/firecrown/likelihood/gauss_family/student_t.py index d5410ce02..8b3baa9e7 100644 --- a/firecrown/likelihood/gauss_family/student_t.py +++ b/firecrown/likelihood/gauss_family/student_t.py @@ -11,7 +11,7 @@ from .gauss_family import GaussFamily from .statistic.statistic import Statistic -from ...parameters import ParamsMap, RequiredParameters +from ...parameters import ParamsMap, RequiredParameters, DerivedParameterCollection class StudentT(GaussFamily): @@ -50,3 +50,7 @@ def _reset_gaussian_family(self): @final def required_parameters_gaussian_family(self): return RequiredParameters([]) + + @final + def _get_derived_parameters_gaussian_family(self) -> DerivedParameterCollection: + return DerivedParameterCollection([]) diff --git a/firecrown/likelihood/likelihood.py b/firecrown/likelihood/likelihood.py index 803fe750b..e44e1f416 100644 --- a/firecrown/likelihood/likelihood.py +++ b/firecrown/likelihood/likelihood.py @@ -62,7 +62,7 @@ def get_params_names(self) -> Optional[List[str]]: @abstractmethod def read(self, sacc_data: sacc.Sacc): - """Read the covariance matrirx for this likelihood from the SACC file.""" + """Read the covariance matrix for this likelihood from the SACC file.""" @abstractmethod def compute_loglike(self, cosmo: pyccl.Cosmology) -> float: diff --git a/firecrown/parameters.py b/firecrown/parameters.py index 805aab1ee..3f527a5c3 100644 --- a/firecrown/parameters.py +++ b/firecrown/parameters.py @@ -3,7 +3,8 @@ """ from __future__ import annotations -from typing import Iterable, Dict, Set, Optional +from typing import Iterable, List, Dict, Set, Tuple, Optional, Iterator +from abc import ABC, abstractmethod def parameter_get_full_name(prefix: Optional[str], param: str) -> str: @@ -93,3 +94,120 @@ def get_params_names(self): for name in params_names_set: yield name + + +class DerivedParameter(ABC): + """Represents a derived parameter generated by an Updatable object + + This class provide the type that encapsulate a derived quantity computed + by an Updatable object during a statistical analysis. + """ + + def __init__(self, section: str, name: str): + """Constructs a new derived parameter.""" + self.section: str = section + self.name: str = name + + def get_full_name(self): + """Constructs the full name using section--name.""" + return f"{self.section}--{self.name}" + + @abstractmethod + def get_val(self): + """Returns the value contained.""" + + +class DerivedParameterScalar(DerivedParameter): + """Represents a derived scalar parameter generated by an Updatable object + + This class provide the type that encapsulate a derived scalar quantity (represented + by a float) computed by an Updatable object during a statistical analysis. + """ + + def __init__(self, section: str, name: str, val: float): + super().__init__(section, name) + + if not isinstance(val, float): + raise TypeError( + "DerivedParameterScalar expects a float but received a " + + str(type(val)) + ) + self.val: float = val + + def get_val(self) -> float: + return self.val + + +class DerivedParameterCollection: + """Represents a list of DerivedParameter objects.""" + + def __init__(self, derived_parameters: List[DerivedParameter]): + """Construct an instance from a List of DerivedParameter objects.""" + + if not all(isinstance(x, DerivedParameter) for x in derived_parameters): + raise TypeError( + "DerivedParameterCollection expects a list of DerivedParameter but " + "received a " + str([str(type(x)) for x in derived_parameters]) + ) + + self.derived_parameters: Dict[str, DerivedParameter] = {} + + for dp in derived_parameters: + self.add_required_parameter(dp) + + def __add__(self, other: Optional[DerivedParameterCollection]): + """Return a new DerivedParameterCollection with the lists of DerivedParameter + objects. + + If other is none return self. Otherwise, constructs a new object representing + the addition. + + Note that this function returns a new object that does not share state + with either argument to the addition operator.""" + if other is None: + return self + + return DerivedParameterCollection( + list(self.derived_parameters.values()) + + list(other.derived_parameters.values()) + ) + + def __eq__(self, other: object): + """Compare two DerivedParameterCollection objects for equality. + + This implementation raises a NotImplemented exception unless both + objects are DerivedParameterCollection objects. + + Two DerivedParameterCollection objects are equal if they contain the same + DerivedParameter objects. + """ + if not isinstance(other, DerivedParameterCollection): + return NotImplemented + return self.derived_parameters == other.derived_parameters + + def __iter__(self) -> Iterator[Tuple[str, str, float]]: + for derived_parameter in self.derived_parameters.values(): + yield ( + derived_parameter.section, + derived_parameter.name, + derived_parameter.get_val(), + ) + + def add_required_parameter(self, derived_parameter: DerivedParameter): + """Adds derived_parameter to the collection, it raises an ValueError if a + required parameter with the same name is already present in the collection. + """ + + required_parameter_full_name = derived_parameter.get_full_name() + if required_parameter_full_name in self.derived_parameters: + raise ValueError( + f"RequiredParameter named {required_parameter_full_name}" + f" is already present in the collection" + ) + else: + self.derived_parameters[required_parameter_full_name] = derived_parameter + + def get_derived_list(self) -> List[DerivedParameter]: + """Implement lazy iteration through the contained parameter names.""" + + return list(self.derived_parameters.values()) diff --git a/firecrown/updatable.py b/firecrown/updatable.py index 411c6f277..6d5d2cc4f 100644 --- a/firecrown/updatable.py +++ b/firecrown/updatable.py @@ -14,10 +14,11 @@ """ from __future__ import annotations -from typing import final +from typing import final, Optional from abc import ABC, abstractmethod from collections import UserList from .parameters import ParamsMap, RequiredParameters +from .parameters import DerivedParameterCollection class Updatable(ABC): @@ -34,6 +35,7 @@ class Updatable(ABC): def __init__(self): """Updatable initialization.""" self._updated: bool = False + self._returned_derived: bool = False @final def update(self, params: ParamsMap): @@ -50,6 +52,7 @@ def update(self, params: ParamsMap): def reset(self): """Reset self by calling the abstract _reset() method, and mark as reset.""" self._updated = False + self._returned_derived = False self._reset() @abstractmethod @@ -85,6 +88,30 @@ def required_parameters(self) -> RequiredParameters: # pragma: no cover """ return RequiredParameters([]) + @final + def get_derived_parameters( + self, + ) -> Optional[DerivedParameterCollection]: + """Returns a collection of derived parameters once per iteration of the + statistical analysis. First call returns the DerivedParameterCollection, + further calls return None. + """ + if not self._returned_derived: + self._returned_derived = True + return self._get_derived_parameters() + + return None + + @abstractmethod + def _get_derived_parameters(self) -> DerivedParameterCollection: + """Abstract method to be implemented by all concrete classes to return their + derived parameters. + + Concrete classes must override this. If no derived parameters are required + derived classes must simply return super()._get_derived_parameters(). + """ + return DerivedParameterCollection([]) + class UpdatableCollection(UserList): @@ -137,6 +164,21 @@ def required_parameters(self) -> RequiredParameters: return result + @final + def get_derived_parameters(self) -> Optional[DerivedParameterCollection]: + """Get all derived parameters if any.""" + has_any_derived = False + derived_parameters = DerivedParameterCollection([]) + for updatable in self: + derived_parameters0 = updatable.get_derived_parameters() + if derived_parameters0 is not None: + derived_parameters = derived_parameters + derived_parameters0 + has_any_derived = True + if has_any_derived: + return derived_parameters + else: + return None + def append(self, item: Updatable) -> None: """Append the given item to self. diff --git a/tests/likelihood/lkdir/lkmodule.py b/tests/likelihood/lkdir/lkmodule.py index 4879570be..b7bd5d8bd 100644 --- a/tests/likelihood/lkdir/lkmodule.py +++ b/tests/likelihood/lkdir/lkmodule.py @@ -1,6 +1,10 @@ import sacc import pyccl -from firecrown.parameters import ParamsMap, RequiredParameters +from firecrown.parameters import ( + ParamsMap, + RequiredParameters, + DerivedParameterCollection, +) from firecrown.likelihood.likelihood import Likelihood @@ -21,6 +25,9 @@ def _reset(self) -> None: def required_parameters(self): return RequiredParameters([]) + def _get_derived_parameters(self) -> DerivedParameterCollection: + return DerivedParameterCollection([]) + def compute_loglike(self, cosmo: pyccl.Cosmology) -> float: return -3.0 * self.placeholder diff --git a/tests/test_parameters.py b/tests/test_parameters.py index c8bda4269..3bb883f69 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -1,15 +1,21 @@ import pytest +import numpy as np from firecrown.parameters import RequiredParameters, parameter_get_full_name, ParamsMap +from firecrown.parameters import ( + DerivedParameterScalar, + DerivedParameterCollection, +) def test_get_params_names_does_not_allow_mutation(): - """The caller of RequiredParameters.get_params_names should not be able to modify the - state of the object on which the call was made.""" + """The caller of RequiredParameters.get_params_names should not be able to modify + the state of the object on which the call was made.""" orig = RequiredParameters(["a", "b"]) names = set(orig.get_params_names()) - assert names == set(["a", "b"]) + assert names == {"a", "b"} + assert names == {"b", "a"} names.add("c") - assert set(orig.get_params_names()) == set(["a", "b"]) + assert set(orig.get_params_names()) == {"a", "b"} def test_params_map(): @@ -19,7 +25,7 @@ def test_params_map(): with pytest.raises(KeyError): _ = my_params.get_from_prefix_param("no_such_prefix", "a") with pytest.raises(KeyError): - _ = my_params.get_from_prefix_param(None, "nosuchname") + _ = my_params.get_from_prefix_param(None, "no_such_name") def test_parameter_get_full_name_reject_empty_name(): @@ -37,5 +43,82 @@ def test_parameter_get_full_name_with_prefix(): def test_parameter_get_full_name_without_prefix(): - full_name = parameter_get_full_name(None, "nomen_mihi") - assert full_name == "nomen_mihi" + full_name = parameter_get_full_name(None, "nomen_foo") + assert full_name == "nomen_foo" + + +def test_derived_parameter_scalar(): + derived_param = DerivedParameterScalar("sec1", "name1", 3.14) + + assert isinstance(derived_param.get_val(), float) + assert derived_param.get_val() == 3.14 + assert derived_param.get_full_name() == "sec1--name1" + + +def test_derived_parameter_wrong_type(): + """Try instantiating DerivedParameter objects with wrong types.""" + + with pytest.raises(TypeError): + derived_param = DerivedParameterScalar( # pylint: disable-msg=E0110,W0612 + "sec1", "name1", "not a float" + ) + with pytest.raises(TypeError): + derived_param = DerivedParameterScalar( # pylint: disable-msg=E0110,W0612 + "sec1", "name1", [3.14] + ) + with pytest.raises(TypeError): + derived_param = DerivedParameterScalar( # pylint: disable-msg=E0110,W0612 + "sec1", "name1", np.array([3.14]) + ) + + +def test_derived_parameters_collection(): + olist = [ + DerivedParameterScalar("sec1", "name1", 3.14), + DerivedParameterScalar("sec2", "name2", 2.72), + ] + orig = DerivedParameterCollection(olist) + clist = orig.get_derived_list() + clist.append(DerivedParameterScalar("sec3", "name3", 0.58)) + assert orig.get_derived_list() == olist + + +def test_derived_parameters_collection_add(): + olist = [ + DerivedParameterScalar("sec1", "name1", 3.14), + DerivedParameterScalar("sec2", "name2", 2.72), + DerivedParameterScalar("sec2", "name3", 0.58), + ] + dpc1 = DerivedParameterCollection(olist) + dpc2 = None + + dpc = dpc1 + dpc2 + + for (section, name, val), derived_parameter in zip(dpc, olist): + assert section == derived_parameter.section + assert name == derived_parameter.name + assert val == derived_parameter.get_val() + + +def test_derived_parameters_collection_add_iter(): + olist1 = [ + DerivedParameterScalar("sec1", "name1", 3.14), + DerivedParameterScalar("sec2", "name2", 2.72), + DerivedParameterScalar("sec2", "name3", 0.58), + ] + dpc1 = DerivedParameterCollection(olist1) + + olist2 = [ + DerivedParameterScalar("sec3", "name1", 3.14e1), + DerivedParameterScalar("sec3", "name2", 2.72e1), + DerivedParameterScalar("sec3", "name3", 0.58e1), + ] + dpc2 = DerivedParameterCollection(olist2) + + dpc = dpc1 + dpc2 + olist = olist1 + olist2 + + for (section, name, val), derived_parameter in zip(dpc, olist): + assert section == derived_parameter.section + assert name == derived_parameter.name + assert val == derived_parameter.get_val() diff --git a/tests/test_updatable.py b/tests/test_updatable.py index fe2731d03..237c4e716 100644 --- a/tests/test_updatable.py +++ b/tests/test_updatable.py @@ -1,6 +1,10 @@ import pytest from firecrown.updatable import Updatable, UpdatableCollection -from firecrown.parameters import RequiredParameters, ParamsMap +from firecrown.parameters import ( + RequiredParameters, + ParamsMap, + DerivedParameterCollection, +) class Missing_update(Updatable): @@ -12,6 +16,9 @@ def required_parameters(self): # pragma: no cover def _reset(self) -> None: pass + def _get_derived_parameters(self) -> DerivedParameterCollection: + return DerivedParameterCollection([]) + class Missing_reset(Updatable): """A type that is abstract because it does not implement required_parameters.""" @@ -22,6 +29,9 @@ def _update(self, params): # pragma: no cover def required_parameters(self): # pragma: no cover pass + def _get_derived_parameters(self) -> DerivedParameterCollection: + return DerivedParameterCollection([]) + class Missing_required_parameters(Updatable): """A type that is abstract because it does not implement required_parameters.""" @@ -32,6 +42,9 @@ def _update(self, params): # pragma: no cover def _reset(self) -> None: pass + def _get_derived_parameters(self) -> DerivedParameterCollection: + return DerivedParameterCollection([]) + class MinimalUpdatable(Updatable): """A concrete time that implements Updatable.""" @@ -51,6 +64,9 @@ def _reset(self) -> None: def required_parameters(self): return RequiredParameters(["a"]) + def _get_derived_parameters(self) -> DerivedParameterCollection: + return DerivedParameterCollection([]) + class SimpleUpdatable(Updatable): """A concrete type that implements Updatable.""" @@ -72,6 +88,9 @@ def _reset(self) -> None: def required_parameters(self): return RequiredParameters(["x", "y"]) + def _get_derived_parameters(self) -> DerivedParameterCollection: + return DerivedParameterCollection([]) + def test_verify_abstract_interface(): with pytest.raises(TypeError): @@ -111,7 +130,7 @@ def test_updatable_collection_appends(): assert coll.required_parameters() == RequiredParameters(["x", "y", "a"]) -def test_updateable_collection_updates(): +def test_updatable_collection_updates(): coll = UpdatableCollection() assert len(coll) == 0 @@ -143,7 +162,7 @@ def test_updatable_collection_construction(): bad_list = [1] with pytest.raises(TypeError): - x = UpdatableCollection(bad_list) # pylint: disable-msg=W0612 + x = UpdatableCollection(bad_list) # pylint: disable-msg=W0612 def test_updatable_collection_insertion():