Skip to content

Commit

Permalink
Move use_source_factory and use_source_factory_metadata_index to sour…
Browse files Browse the repository at this point in the history
…ce_factories
  • Loading branch information
marcpaterno committed Sep 30, 2024
1 parent f96a8ec commit af87f6b
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 64 deletions.
61 changes: 61 additions & 0 deletions firecrown/likelihood/source_factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Factory functions for creating sources."""

from firecrown.likelihood.number_counts import NumberCountsFactory, NumberCounts
from firecrown.likelihood.weak_lensing import WeakLensingFactory, WeakLensing
from firecrown.metadata_types import InferredGalaxyZDist, Measurement, Galaxies


def use_source_factory(
inferred_galaxy_zdist: InferredGalaxyZDist,
measurement: Measurement,
wl_factory: WeakLensingFactory | None = None,
nc_factory: NumberCountsFactory | None = None,
) -> WeakLensing | NumberCounts:
"""Apply the factory to the inferred galaxy redshift distribution."""
source: WeakLensing | NumberCounts
if measurement not in inferred_galaxy_zdist.measurements:
raise ValueError(
f"Measurement {measurement} not found in inferred galaxy redshift "
f"distribution {inferred_galaxy_zdist.bin_name}!"
)

match measurement:
case Galaxies.COUNTS:
assert nc_factory is not None
source = nc_factory.create(inferred_galaxy_zdist)
case (
Galaxies.SHEAR_E
| Galaxies.SHEAR_T
| Galaxies.SHEAR_MINUS
| Galaxies.SHEAR_PLUS
):
assert wl_factory is not None
source = wl_factory.create(inferred_galaxy_zdist)
case _:
raise ValueError(f"Measurement {measurement} not supported!")
return source


def use_source_factory_metadata_index(
sacc_tracer: str,
measurement: Measurement,
wl_factory: WeakLensingFactory | None = None,
nc_factory: NumberCountsFactory | None = None,
) -> WeakLensing | NumberCounts:
"""Apply the factory to create a source from metadata only."""
source: WeakLensing | NumberCounts
match measurement:
case Galaxies.COUNTS:
assert nc_factory is not None
source = nc_factory.create_from_metadata_only(sacc_tracer)
case (
Galaxies.SHEAR_E
| Galaxies.SHEAR_T
| Galaxies.SHEAR_MINUS
| Galaxies.SHEAR_PLUS
):
assert wl_factory is not None
source = wl_factory.create_from_metadata_only(sacc_tracer)
case _:
raise ValueError(f"Measurement {measurement} not supported!")
return source
65 changes: 4 additions & 61 deletions firecrown/likelihood/two_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,22 @@
apply_theta_min_max,
)
from firecrown.likelihood.source import Source, Tracer
from firecrown.likelihood.source_factories import (
use_source_factory,
use_source_factory_metadata_index,
)
from firecrown.likelihood.weak_lensing import (
WeakLensingFactory,
WeakLensing,
)
from firecrown.likelihood.number_counts import (
NumberCountsFactory,
NumberCounts,
)
from firecrown.likelihood.statistic import (
DataVector,
Statistic,
TheoryVector,
)
from firecrown.metadata_types import (
Galaxies,
InferredGalaxyZDist,
Measurement,
TRACER_NAMES_TOTAL,
TracerNames,
TwoPointHarmonic,
Expand Down Expand Up @@ -74,62 +73,6 @@
}


def use_source_factory(
inferred_galaxy_zdist: InferredGalaxyZDist,
measurement: Measurement,
wl_factory: WeakLensingFactory | None = None,
nc_factory: NumberCountsFactory | None = None,
) -> WeakLensing | NumberCounts:
"""Apply the factory to the inferred galaxy redshift distribution."""
source: WeakLensing | NumberCounts
if measurement not in inferred_galaxy_zdist.measurements:
raise ValueError(
f"Measurement {measurement} not found in inferred galaxy redshift "
f"distribution {inferred_galaxy_zdist.bin_name}!"
)

match measurement:
case Galaxies.COUNTS:
assert nc_factory is not None
source = nc_factory.create(inferred_galaxy_zdist)
case (
Galaxies.SHEAR_E
| Galaxies.SHEAR_T
| Galaxies.SHEAR_MINUS
| Galaxies.SHEAR_PLUS
):
assert wl_factory is not None
source = wl_factory.create(inferred_galaxy_zdist)
case _:
raise ValueError(f"Measurement {measurement} not supported!")
return source


def use_source_factory_metadata_index(
sacc_tracer: str,
measurement: Measurement,
wl_factory: WeakLensingFactory | None = None,
nc_factory: NumberCountsFactory | None = None,
) -> WeakLensing | NumberCounts:
"""Apply the factory to create a source from metadata only."""
source: WeakLensing | NumberCounts
match measurement:
case Galaxies.COUNTS:
assert nc_factory is not None
source = nc_factory.create_from_metadata_only(sacc_tracer)
case (
Galaxies.SHEAR_E
| Galaxies.SHEAR_T
| Galaxies.SHEAR_MINUS
| Galaxies.SHEAR_PLUS
):
assert wl_factory is not None
source = wl_factory.create_from_metadata_only(sacc_tracer)
case _:
raise ValueError(f"Measurement {measurement} not supported!")
return source


class TwoPoint(Statistic):
"""A statistic that represents the correlation between two measurements.
Expand Down
6 changes: 4 additions & 2 deletions tests/likelihood/gauss_family/statistic/test_two_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
TwoPoint,
TracerNames,
TRACER_NAMES_TOTAL,
use_source_factory,
use_source_factory_metadata_index,
WeakLensingFactory,
NumberCountsFactory,
)
from firecrown.likelihood.source_factories import (
use_source_factory,
use_source_factory_metadata_index,
)
from firecrown.generators.two_point import (
log_linear_ells,
generate_bin_centers,
Expand Down
3 changes: 2 additions & 1 deletion tests/metadata/test_metadata_two_point_sacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
extract_all_harmonic_data,
extract_all_real_data,
)
from firecrown.likelihood.two_point import TwoPoint, use_source_factory
from firecrown.likelihood.two_point import TwoPoint
from firecrown.likelihood.source_factories import use_source_factory


@pytest.fixture(name="sacc_galaxy_src0_src0_invalid_data_type")
Expand Down

0 comments on commit af87f6b

Please sign in to comment.