diff --git a/firecrown/likelihood/number_counts.py b/firecrown/likelihood/number_counts.py index fe4b528f..ecc7406f 100644 --- a/firecrown/likelihood/number_counts.py +++ b/firecrown/likelihood/number_counts.py @@ -646,6 +646,7 @@ class NumberCountsFactory(BaseModel): per_bin_systematics: Sequence[NumberCountsSystematicFactory] global_systematics: Sequence[NumberCountsSystematicFactory] + include_rsd: bool = False def model_post_init(self, _) -> None: """Initialize the NumberCountsFactory. @@ -674,7 +675,9 @@ def create(self, inferred_zdist: InferredGalaxyZDist) -> NumberCounts: ] systematics.extend(self._global_systematics_instances) - nc = NumberCounts.create_ready(inferred_zdist, systematics=systematics) + nc = NumberCounts.create_ready( + inferred_zdist, systematics=systematics, has_rsd=self.include_rsd + ) self._cache[inferred_zdist_id] = nc return nc @@ -697,7 +700,9 @@ def create_from_metadata_only( ] systematics.extend(self._global_systematics_instances) - nc = NumberCounts(sacc_tracer=sacc_tracer, systematics=systematics) + nc = NumberCounts( + sacc_tracer=sacc_tracer, systematics=systematics, has_rsd=self.include_rsd + ) self._cache[sacc_tracer_id] = nc return nc diff --git a/tests/likelihood/gauss_family/statistic/test_two_point.py b/tests/likelihood/gauss_family/statistic/test_two_point.py index f71fda8b..599769c3 100644 --- a/tests/likelihood/gauss_family/statistic/test_two_point.py +++ b/tests/likelihood/gauss_family/statistic/test_two_point.py @@ -45,6 +45,12 @@ from firecrown.data_types import TwoPointMeasurement +@pytest.fixture(name="include_rsd", params=[True, False], ids=["rsd", "no_rsd"]) +def fixture_include_rsd(request) -> bool: + """Return whether to include RSD in the test.""" + return request.param + + @pytest.fixture(name="source_0") def fixture_source_0() -> NumberCounts: """Return an almost-default NumberCounts source.""" @@ -98,20 +104,60 @@ def fixture_harmonic_data_no_window(harmonic_two_point_xy) -> TwoPointMeasuremen return tpm -def test_ell_for_xi_no_rounding(): +@pytest.fixture(name="wl_factory") +def fixture_wl_factory() -> WeakLensingFactory: + """Return a WeakLensingFactory object.""" + return WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) + + +@pytest.fixture(name="nc_factory") +def fixture_nc_factory(include_rsd: bool) -> NumberCountsFactory: + """Return a NumberCountsFactory object.""" + return NumberCountsFactory( + per_bin_systematics=[], global_systematics=[], include_rsd=include_rsd + ) + + +@pytest.fixture(name="two_point_with_window") +def fixture_two_point_with_window( + harmonic_data_with_window: TwoPointMeasurement, + wl_factory: WeakLensingFactory, + nc_factory: NumberCountsFactory, +) -> TwoPoint: + """Return a TwoPoint object with a window.""" + two_points = TwoPoint.from_measurement( + [harmonic_data_with_window], wl_factory=wl_factory, nc_factory=nc_factory + ) + return two_points.pop() + + +@pytest.fixture(name="two_point_without_window") +def fixture_two_point_without_window( + harmonic_data_no_window: TwoPointMeasurement, + wl_factory: WeakLensingFactory, + nc_factory: NumberCountsFactory, +) -> TwoPoint: + """Return a TwoPoint object without a window.""" + two_points = TwoPoint.from_measurement( + [harmonic_data_no_window], wl_factory=wl_factory, nc_factory=nc_factory + ) + return two_points.pop() + + +def test_ell_for_xi_no_rounding() -> None: res = _ell_for_xi(minimum=0, midpoint=5, maximum=80, n_log=5) expected = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0, 40.0, 80.0]) assert res.shape == expected.shape assert np.allclose(expected, res) -def test_ell_for_xi_doing_rounding(): +def test_ell_for_xi_doing_rounding() -> None: res = _ell_for_xi(minimum=1, midpoint=3, maximum=100, n_log=5) expected = np.array([1.0, 2.0, 3.0, 7.0, 17.0, 42.0, 100.0]) assert np.allclose(expected, res) -def test_compute_theory_vector(source_0: NumberCounts): +def test_compute_theory_vector(source_0: NumberCounts) -> None: # To create the TwoPoint object we need at least one source. statistic = TwoPoint("galaxy_density_xi", source_0, source_0) assert isinstance(statistic, TwoPoint) @@ -122,7 +168,7 @@ def test_compute_theory_vector(source_0: NumberCounts): # assert isinstance(prediction, TheoryVector) -def test_tracer_names(): +def test_tracer_names() -> None: assert TracerNames("", "") == TRACER_NAMES_TOTAL tn1 = TracerNames("cow", "pig") @@ -137,7 +183,7 @@ def test_tracer_names(): _ = tn1[2] -def test_two_point_src0_src0_window(sacc_galaxy_cells_src0_src0_window): +def test_two_point_src0_src0_window(sacc_galaxy_cells_src0_src0_window) -> None: """This test also makes sure that TwoPoint theory calculations are repeatable.""" sacc_data, _, _ = sacc_galaxy_cells_src0_src0_window @@ -166,7 +212,7 @@ def test_two_point_src0_src0_window(sacc_galaxy_cells_src0_src0_window): assert np.array_equal(result1, result2) -def test_two_point_src0_src0_no_window(sacc_galaxy_cells_src0_src0_no_window): +def test_two_point_src0_src0_no_window(sacc_galaxy_cells_src0_src0_no_window) -> None: """This test also makes sure that TwoPoint theory calculations are repeatable.""" sacc_data, _, _ = sacc_galaxy_cells_src0_src0_no_window @@ -195,7 +241,7 @@ def test_two_point_src0_src0_no_window(sacc_galaxy_cells_src0_src0_no_window): assert np.array_equal(result1, result2) -def test_two_point_src0_src0_no_data_lin(sacc_galaxy_cells_src0_src0_no_data): +def test_two_point_src0_src0_no_data_lin(sacc_galaxy_cells_src0_src0_no_data) -> None: sacc_data, _, _ = sacc_galaxy_cells_src0_src0_no_data src0 = WeakLensing(sacc_tracer="src0") @@ -224,7 +270,7 @@ def test_two_point_src0_src0_no_data_lin(sacc_galaxy_cells_src0_src0_no_data): assert all(statistic.ells <= 100) -def test_two_point_src0_src0_no_data_log(sacc_galaxy_cells_src0_src0_no_data): +def test_two_point_src0_src0_no_data_log(sacc_galaxy_cells_src0_src0_no_data) -> None: sacc_data, _, _ = sacc_galaxy_cells_src0_src0_no_data src0 = WeakLensing(sacc_tracer="src0") @@ -253,7 +299,7 @@ def test_two_point_src0_src0_no_data_log(sacc_galaxy_cells_src0_src0_no_data): assert all(statistic.ells <= 100) -def test_two_point_lens0_lens0_no_data(sacc_galaxy_xis_lens0_lens0_no_data): +def test_two_point_lens0_lens0_no_data(sacc_galaxy_xis_lens0_lens0_no_data) -> None: sacc_data, _, _ = sacc_galaxy_xis_lens0_lens0_no_data src0 = NumberCounts(sacc_tracer="lens0") @@ -282,7 +328,7 @@ def test_two_point_lens0_lens0_no_data(sacc_galaxy_xis_lens0_lens0_no_data): assert all(statistic.thetas <= 1.0) -def test_two_point_src0_src0_cuts(sacc_galaxy_cells_src0_src0): +def test_two_point_src0_src0_cuts(sacc_galaxy_cells_src0_src0) -> None: sacc_data, _, _ = sacc_galaxy_cells_src0_src0 src0 = WeakLensing(sacc_tracer="src0") @@ -309,7 +355,7 @@ def test_two_point_src0_src0_cuts(sacc_galaxy_cells_src0_src0): statistic.compute_theory_vector(tools) -def test_two_point_lens0_lens0_cuts(sacc_galaxy_xis_lens0_lens0): +def test_two_point_lens0_lens0_cuts(sacc_galaxy_xis_lens0_lens0) -> None: sacc_data, _, _ = sacc_galaxy_xis_lens0_lens0 src0 = NumberCounts(sacc_tracer="lens0") @@ -334,7 +380,7 @@ def test_two_point_lens0_lens0_cuts(sacc_galaxy_xis_lens0_lens0): statistic.compute_theory_vector(tools) -def test_two_point_lens0_lens0_config(sacc_galaxy_xis_lens0_lens0): +def test_two_point_lens0_lens0_config(sacc_galaxy_xis_lens0_lens0) -> None: sacc_data, _, _ = sacc_galaxy_xis_lens0_lens0 src0 = NumberCounts(sacc_tracer="lens0") @@ -365,7 +411,7 @@ def test_two_point_lens0_lens0_config(sacc_galaxy_xis_lens0_lens0): statistic.compute_theory_vector(tools) -def test_two_point_src0_src0_no_data_error(sacc_galaxy_cells_src0_src0_no_data): +def test_two_point_src0_src0_no_data_error(sacc_galaxy_cells_src0_src0_no_data) -> None: sacc_data, _, _ = sacc_galaxy_cells_src0_src0_no_data src0 = WeakLensing(sacc_tracer="src0") @@ -382,7 +428,9 @@ def test_two_point_src0_src0_no_data_error(sacc_galaxy_cells_src0_src0_no_data): statistic.read(sacc_data) -def test_two_point_lens0_lens0_no_data_error(sacc_galaxy_xis_lens0_lens0_no_data): +def test_two_point_lens0_lens0_no_data_error( + sacc_galaxy_xis_lens0_lens0_no_data, +) -> None: sacc_data, _, _ = sacc_galaxy_xis_lens0_lens0_no_data src0 = NumberCounts(sacc_tracer="lens0") @@ -399,7 +447,9 @@ def test_two_point_lens0_lens0_no_data_error(sacc_galaxy_xis_lens0_lens0_no_data statistic.read(sacc_data) -def test_two_point_src0_src0_data_and_conf_warn(sacc_galaxy_cells_src0_src0_window): +def test_two_point_src0_src0_data_and_conf_warn( + sacc_galaxy_cells_src0_src0_window, +) -> None: sacc_data, _, _ = sacc_galaxy_cells_src0_src0_window src0 = WeakLensing(sacc_tracer="src0") @@ -422,7 +472,7 @@ def test_two_point_src0_src0_data_and_conf_warn(sacc_galaxy_cells_src0_src0_wind statistic.read(sacc_data) -def test_two_point_lens0_lens0_data_and_conf_warn(sacc_galaxy_xis_lens0_lens0): +def test_two_point_lens0_lens0_data_and_conf_warn(sacc_galaxy_xis_lens0_lens0) -> None: sacc_data, _, _ = sacc_galaxy_xis_lens0_lens0 src0 = NumberCounts(sacc_tracer="lens0") @@ -445,22 +495,26 @@ def test_two_point_lens0_lens0_data_and_conf_warn(sacc_galaxy_xis_lens0_lens0): statistic.read(sacc_data) -def test_use_source_factory(harmonic_bin_1: InferredGalaxyZDist): - wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) - +def test_use_source_factory( + harmonic_bin_1: InferredGalaxyZDist, + wl_factory: WeakLensingFactory, + nc_factory: NumberCountsFactory, +) -> None: measurement = list(harmonic_bin_1.measurements)[0] source = use_source_factory(harmonic_bin_1, measurement, wl_factory, nc_factory) if measurement in GALAXY_LENS_TYPES: assert isinstance(source, NumberCounts) + assert source.has_rsd == nc_factory.include_rsd elif measurement in GALAXY_SOURCE_TYPES: assert isinstance(source, WeakLensing) else: assert False, f"Unknown measurement type: {measurement}" -def test_use_source_factory_invalid_measurement(harmonic_bin_1: InferredGalaxyZDist): +def test_use_source_factory_invalid_measurement( + harmonic_bin_1: InferredGalaxyZDist, +) -> None: with pytest.raises( ValueError, match="Measurement .* not found in inferred galaxy redshift distribution .*", @@ -468,41 +522,45 @@ def test_use_source_factory_invalid_measurement(harmonic_bin_1: InferredGalaxyZD use_source_factory(harmonic_bin_1, Galaxies.SHEAR_MINUS, None, None) -def test_use_source_factory_metadata_only_counts(): - wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) +def test_use_source_factory_metadata_only_counts( + wl_factory: WeakLensingFactory, + nc_factory: NumberCountsFactory, +) -> None: source = use_source_factory_metadata_index( "bin1", Galaxies.COUNTS, wl_factory=wl_factory, nc_factory=nc_factory ) assert isinstance(source, NumberCounts) + assert source.has_rsd == nc_factory.include_rsd -def test_use_source_factory_metadata_only_shear(): - wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) +def test_use_source_factory_metadata_only_shear( + wl_factory: WeakLensingFactory, + nc_factory: NumberCountsFactory, +) -> None: source = use_source_factory_metadata_index( "bin1", Galaxies.SHEAR_E, wl_factory=wl_factory, nc_factory=nc_factory ) assert isinstance(source, WeakLensing) -def test_use_source_factory_metadata_only_invalid_measurement(): - wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) +def test_use_source_factory_metadata_only_invalid_measurement( + wl_factory: WeakLensingFactory, + nc_factory: NumberCountsFactory, +) -> None: with pytest.raises(ValueError, match="Unknown measurement type encountered .*"): use_source_factory_metadata_index( "bin1", 120, wl_factory=wl_factory, nc_factory=nc_factory # type: ignore ) -def test_two_point_wrong_type(): +def test_two_point_wrong_type() -> None: with pytest.raises(ValueError, match="The SACC data type cow is not supported!"): TwoPoint( "cow", WeakLensing(sacc_tracer="calma"), WeakLensing(sacc_tracer="fernando") ) -def test_from_metadata_harmonic_wrong_metadata(): +def test_from_metadata_harmonic_wrong_metadata() -> None: with pytest.raises( ValueError, match=re.escape("Metadata of type is not supported") ): @@ -511,7 +569,7 @@ def test_from_metadata_harmonic_wrong_metadata(): ) -def test_use_source_factory_metadata_only_wrong_measurement(): +def test_use_source_factory_metadata_only_wrong_measurement() -> None: unknown_type = MagicMock() unknown_type.configure_mock(__eq__=MagicMock(return_value=False)) @@ -521,9 +579,10 @@ def test_use_source_factory_metadata_only_wrong_measurement(): ) -def test_from_metadata_only_harmonic(): - wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) +def test_from_metadata_only_harmonic( + wl_factory: WeakLensingFactory, + nc_factory: NumberCountsFactory, +) -> None: metadata: TwoPointHarmonicIndex = { "data_type": "galaxy_density_xi", "tracer_names": TracerNames("lens0", "lens0"), @@ -537,9 +596,10 @@ def test_from_metadata_only_harmonic(): assert not two_point.ready -def test_from_metadata_only_real(): - wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) +def test_from_metadata_only_real( + wl_factory: WeakLensingFactory, + nc_factory: NumberCountsFactory, +) -> None: metadata: TwoPointRealIndex = { "data_type": "galaxy_shear_xi_plus", "tracer_names": TracerNames("src0", "src0"), @@ -554,50 +614,29 @@ def test_from_metadata_only_real(): def test_from_measurement_compute_theory_vector_window( - harmonic_data_with_window: TwoPointMeasurement, -): - wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) - two_points = TwoPoint.from_measurement( - [harmonic_data_with_window], wl_factory=wl_factory, nc_factory=nc_factory - ) - two_point: TwoPoint = two_points.pop() - - assert isinstance(two_point, TwoPoint) - assert two_point.ready + two_point_with_window: TwoPoint, +) -> None: + assert isinstance(two_point_with_window, TwoPoint) + assert two_point_with_window.ready - req_params = two_point.required_parameters() + req_params = two_point_with_window.required_parameters() default_values = req_params.get_default_values() params = ParamsMap(default_values) tools = ModelingTools() tools.update(params) tools.prepare(pyccl.CosmologyVanillaLCDM()) - two_point.update(params) + two_point_with_window.update(params) - prediction = two_point.compute_theory_vector(tools) + prediction = two_point_with_window.compute_theory_vector(tools) assert isinstance(prediction, TheoryVector) assert prediction.shape == (4,) def test_from_measurement_compute_theory_vector_window_check( - harmonic_data_with_window: TwoPointMeasurement, - harmonic_data_no_window: TwoPointMeasurement, -): - wl_factory = WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) - nc_factory = NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) - - two_points_with_window = TwoPoint.from_measurement( - [harmonic_data_with_window], wl_factory=wl_factory, nc_factory=nc_factory - ) - two_point_with_window: TwoPoint = two_points_with_window.pop() - - two_points_without_window = TwoPoint.from_measurement( - [harmonic_data_no_window], wl_factory=wl_factory, nc_factory=nc_factory - ) - two_point_without_window: TwoPoint = two_points_without_window.pop() - + two_point_with_window: TwoPoint, two_point_without_window: TwoPoint +) -> None: assert isinstance(two_point_with_window, TwoPoint) assert two_point_with_window.ready