diff --git a/firecrown/likelihood/gauss_family/statistic/two_point.py b/firecrown/likelihood/gauss_family/statistic/two_point.py index 4573b3cf..1d050bc7 100644 --- a/firecrown/likelihood/gauss_family/statistic/two_point.py +++ b/firecrown/likelihood/gauss_family/statistic/two_point.py @@ -774,33 +774,3 @@ def calculate_pk( else: raise ValueError(f"No power spectrum for {pk_name} can be found.") return pk - - -class SourceFactory: - """A factory for creating sources.""" - - def __init__( - self, - wl_factory: WeakLensingFactory | None = None, - nc_factory: NumberCountsFactory | None = None, - ) -> None: - """Initialize the SourceFactory.""" - self.wl_factory = wl_factory - self.nc_factory = nc_factory - - def create(self, inferred_galaxy_zdist: InferredGalaxyZDist) -> Source: - """Create a source from the inferred galaxy redshift distribution.""" - match inferred_galaxy_zdist.measured_type: - - case GalaxyMeasuredType.COUNTS: - assert self.nc_factory is not None - return self.nc_factory.create(inferred_galaxy_zdist) - - case GalaxyMeasuredType.SHEAR_E | GalaxyMeasuredType.SHEAR_T: - assert self.wl_factory is not None - return self.wl_factory.create(inferred_galaxy_zdist) - case _: - raise ValueError( - f"Measured type {inferred_galaxy_zdist.measured_type} " - f"not supported!" - ) diff --git a/tests/metadata/test_metadata_two_point.py b/tests/metadata/test_metadata_two_point.py index 40321b5e..7ec0fd1c 100644 --- a/tests/metadata/test_metadata_two_point.py +++ b/tests/metadata/test_metadata_two_point.py @@ -28,28 +28,68 @@ type_to_sacc_string_real as real, Window, ) +from firecrown.likelihood.gauss_family.statistic.source.source import SourceGalaxy +import firecrown.likelihood.gauss_family.statistic.source.weak_lensing as wl +import firecrown.likelihood.gauss_family.statistic.source.number_counts as nc +from firecrown.likelihood.gauss_family.statistic.two_point import TwoPoint -@pytest.fixture(name="bin_1") -def make_bin_1() -> InferredGalaxyZDist: +@pytest.fixture( + name="harmonic_bin_1", + params=[GalaxyMeasuredType.COUNTS, GalaxyMeasuredType.SHEAR_E], +) +def make_harmonic_bin_1(request) -> InferredGalaxyZDist: """Generate an InferredGalaxyZDist object with 5 bins.""" x = InferredGalaxyZDist( bin_name="bin_1", z=np.linspace(0, 1, 5), dndz=np.array([0.1, 0.5, 0.2, 0.3, 0.4]), - measured_type=GalaxyMeasuredType.COUNTS, + measured_type=request.param, ) return x -@pytest.fixture(name="bin_2") -def make_bin_2() -> InferredGalaxyZDist: +@pytest.fixture( + name="harmonic_bin_2", + params=[GalaxyMeasuredType.COUNTS, GalaxyMeasuredType.SHEAR_E], +) +def make_harmonic_bin_2(request) -> InferredGalaxyZDist: """Generate an InferredGalaxyZDist object with 3 bins.""" x = InferredGalaxyZDist( bin_name="bin_2", z=np.linspace(0, 1, 3), dndz=np.array([0.1, 0.5, 0.4]), - measured_type=GalaxyMeasuredType.COUNTS, + measured_type=request.param, + ) + return x + + +@pytest.fixture( + name="real_bin_1", + params=[GalaxyMeasuredType.COUNTS, GalaxyMeasuredType.SHEAR_T], +) +def make_real_bin_1(request) -> InferredGalaxyZDist: + """Generate an InferredGalaxyZDist object with 5 bins.""" + x = InferredGalaxyZDist( + bin_name="bin_1", + z=np.linspace(0, 1, 5), + dndz=np.array([0.1, 0.5, 0.2, 0.3, 0.4]), + measured_type=request.param, + ) + return x + + +@pytest.fixture( + name="real_bin_2", + params=[GalaxyMeasuredType.COUNTS, GalaxyMeasuredType.SHEAR_T], +) +def make_real_bin_2(request) -> InferredGalaxyZDist: + """Generate an InferredGalaxyZDist object with 3 bins.""" + x = InferredGalaxyZDist( + bin_name="bin_2", + z=np.linspace(0, 1, 3), + dndz=np.array([0.1, 0.5, 0.4]), + measured_type=request.param, ) return x @@ -71,14 +111,26 @@ def make_window_1() -> Window: @pytest.fixture(name="two_point_cwindow_1") def make_two_point_cwindow_1( - window_1: Window, bin_1: InferredGalaxyZDist, bin_2: InferredGalaxyZDist + window_1: Window, + harmonic_bin_1: InferredGalaxyZDist, + harmonic_bin_2: InferredGalaxyZDist, ) -> TwoPointCWindow: """Generate a TwoPointCWindow object with 100 ells.""" - xy = TwoPointXY(x=bin_1, y=bin_2) + xy = TwoPointXY(x=harmonic_bin_1, y=harmonic_bin_2) two_point = TwoPointCWindow(XY=xy, window=window_1) return two_point +@pytest.fixture(name="wl_factory") +def make_wl_factory(): + return wl.WeakLensingFactory(per_bin_systematics=[], global_systematics=[]) + + +@pytest.fixture(name="nc_factory") +def make_nc_factory(): + return nc.NumberCountsFactory(per_bin_systematics=[], global_systematics=[]) + + def test_order_enums(): assert compare_enums(CMBMeasuredType.CONVERGENCE, ClusterMeasuredType.COUNTS) < 0 assert compare_enums(ClusterMeasuredType.COUNTS, CMBMeasuredType.CONVERGENCE) > 0 @@ -643,18 +695,18 @@ def test_measured_type_serialization(): assert t == recovered -def test_inferred_galaxy_zdist_serialization(bin_1: InferredGalaxyZDist): - s = bin_1.to_yaml() +def test_inferred_galaxy_zdist_serialization(harmonic_bin_1: InferredGalaxyZDist): + s = harmonic_bin_1.to_yaml() # Take a look at how hideous the generated string # is. recovered = InferredGalaxyZDist.from_yaml(s) - assert bin_1 == recovered + assert harmonic_bin_1 == recovered def test_two_point_xy_serialization( - bin_1: InferredGalaxyZDist, bin_2: InferredGalaxyZDist + harmonic_bin_1: InferredGalaxyZDist, harmonic_bin_2: InferredGalaxyZDist ): - xy = TwoPointXY(x=bin_1, y=bin_2) + xy = TwoPointXY(x=harmonic_bin_1, y=harmonic_bin_2) s = xy.to_yaml() # Take a look at how hideous the generated string # is. @@ -663,10 +715,10 @@ def test_two_point_xy_serialization( def test_two_point_cells_serialization( - bin_1: InferredGalaxyZDist, bin_2: InferredGalaxyZDist + harmonic_bin_1: InferredGalaxyZDist, harmonic_bin_2: InferredGalaxyZDist ): ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) - xy = TwoPointXY(x=bin_1, y=bin_2) + xy = TwoPointXY(x=harmonic_bin_1, y=harmonic_bin_2) cells = TwoPointCells(ells=ells, XY=xy) s = cells.to_yaml() recovered = TwoPointCells.from_yaml(s) @@ -686,11 +738,33 @@ def test_two_point_cwindow_serialization(two_point_cwindow_1: TwoPointCWindow): def test_two_point_xi_theta_serialization( - bin_1: InferredGalaxyZDist, bin_2: InferredGalaxyZDist + real_bin_1: InferredGalaxyZDist, real_bin_2: InferredGalaxyZDist ): - xy = TwoPointXY(x=bin_1, y=bin_2) + xy = TwoPointXY(x=real_bin_1, y=real_bin_2) theta = np.array(np.linspace(0, 10, 10)) xi_theta = TwoPointXiTheta(XY=xy, thetas=theta) s = xi_theta.to_yaml() recovered = TwoPointXiTheta.from_yaml(s) assert xi_theta == recovered + + +def test_two_point_from_metadata_cells( + harmonic_bin_1, harmonic_bin_2, wl_factory, nc_factory +): + ells = np.array(np.linspace(0, 100, 100), dtype=np.int64) + xy = TwoPointXY(x=harmonic_bin_1, y=harmonic_bin_2) + cells = TwoPointCells(ells=ells, XY=xy) + two_point = TwoPoint.from_metadata_cells([cells], wl_factory, nc_factory).pop() + + assert two_point is not None + assert isinstance(two_point, TwoPoint) + assert two_point.sacc_data_type == cells.get_sacc_name() + + assert isinstance(two_point.source0, SourceGalaxy) + assert isinstance(two_point.source1, SourceGalaxy) + + assert_array_equal(two_point.source0.tracer_args.z, harmonic_bin_1.z) + assert_array_equal(two_point.source0.tracer_args.z, harmonic_bin_1.z) + + assert_array_equal(two_point.source0.tracer_args.dndz, harmonic_bin_1.dndz) + assert_array_equal(two_point.source1.tracer_args.dndz, harmonic_bin_2.dndz)