diff --git a/firecrown/likelihood/gauss_family/statistic/source/number_counts.py b/firecrown/likelihood/gauss_family/statistic/source/number_counts.py index 6bd010301..daf0b017a 100644 --- a/firecrown/likelihood/gauss_family/statistic/source/number_counts.py +++ b/firecrown/likelihood/gauss_family/statistic/source/number_counts.py @@ -3,16 +3,22 @@ """ from __future__ import annotations -from typing import List, Tuple, Optional, final +from typing import List, Union, Tuple, Optional, final from dataclasses import dataclass, replace from abc import abstractmethod import numpy as np import numpy.typing as npt import pyccl -from scipy.interpolate import Akima1DInterpolator -from .source import Source, Tracer, SourceSystematic +from .source import ( + Tracer, + SourceGalaxy, + SourceGalaxyArgs, + SourceGalaxySystematic, + SourceGalaxyPhotoZShift, +) + from ..... import parameters from .....modeling_tools import ModelingTools @@ -27,12 +33,10 @@ @dataclass(frozen=True) -class NumberCountsArgs: +class NumberCountsArgs(SourceGalaxyArgs): """Class for number counts tracer builder argument.""" scale: float - z: npt.NDArray[np.float64] - dndz: npt.NDArray[np.float64] bias: Optional[npt.NDArray[np.float64]] = None mag_bias: Optional[Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]] = None has_pt: bool = False @@ -41,7 +45,7 @@ class NumberCountsArgs: b_s: Optional[Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]] = None -class NumberCountsSystematic(SourceSystematic): +class NumberCountsSystematic(SourceGalaxySystematic[NumberCountsArgs]): """Abstract base class for systematics for Number Counts sources. Derived classes must implement :python`apply` with the correct signature.""" @@ -53,6 +57,10 @@ def apply( """Apply method to include systematics in the tracer_arg.""" +class PhotoZShift(SourceGalaxyPhotoZShift[NumberCountsArgs]): + """Photo-z shift systematic.""" + + class LinearBiasSystematic(NumberCountsSystematic): """Linear bias systematic. @@ -248,43 +256,7 @@ def apply( ) -class PhotoZShift(NumberCountsSystematic): - """A photo-z shift bias. - - This systematic shifts the photo-z distribution by some ammount `delta_z`. - - The following parameters are special Updatable parameters, which means that - they can be updated by the sampler, sacc_tracer is going to be used as a - prefix for the parameters: - - :ivar delta_z: the photo-z shift. - """ - - def __init__(self, sacc_tracer: str): - """Create a PhotoZShift object, using the specified tracer name. - - :param sacc_tracer: the name of the tracer in the SACC file. This is used - as a prefix for its parameters. - """ - super().__init__(parameter_prefix=sacc_tracer) - - self.delta_z = parameters.create() - - def apply(self, tools: ModelingTools, tracer_arg: NumberCountsArgs): - """Apply a shift to the photo-z distribution of a source.""" - - dndz_interp = Akima1DInterpolator(tracer_arg.z, tracer_arg.dndz) - - dndz = dndz_interp(tracer_arg.z - self.delta_z, extrapolate=False) - dndz[np.isnan(dndz)] = 0.0 - - return replace( - tracer_arg, - dndz=dndz, - ) - - -class NumberCounts(Source): +class NumberCounts(SourceGalaxy[NumberCountsArgs]): """Source class for number counts.""" systematics: UpdatableCollection @@ -297,7 +269,11 @@ def __init__( has_rsd: bool = False, derived_scale: bool = False, scale: float = 1.0, - systematics: Optional[List[NumberCountsSystematic]] = None, + systematics: Optional[ + List[ + Union[NumberCountsSystematic, SourceGalaxySystematic[NumberCountsArgs]] + ] + ] = None, ): """Initialize the NumberCounts object. @@ -309,7 +285,7 @@ def __init__( :param scale: the initial scale of the tracer. :param systematics: a list of systematics to apply to the tracer. """ - super().__init__(sacc_tracer) + super().__init__(sacc_tracer=sacc_tracer, systematics=systematics) self.sacc_tracer = sacc_tracer self.has_rsd = has_rsd @@ -350,16 +326,15 @@ def _read(self, sacc_data): sacc_data : sacc.Sacc The data in the sacc format. """ - tracer = sacc_data.get_tracer(self.sacc_tracer) - z = getattr(tracer, "z").copy().flatten() - nz = getattr(tracer, "nz").copy().flatten() - indices = np.argsort(z) - z = z[indices] - nz = nz[indices] self.tracer_args = NumberCountsArgs( - scale=self.scale, z=z, dndz=nz, bias=None, mag_bias=None + scale=self.scale, + z=np.array([]), + dndz=np.array([]), + bias=None, + mag_bias=None, ) + super()._read(sacc_data) def create_tracers(self, tools: ModelingTools): tracer_args = self.tracer_args diff --git a/firecrown/likelihood/gauss_family/statistic/source/source.py b/firecrown/likelihood/gauss_family/statistic/source/source.py index 10fc81b66..cc53bf984 100644 --- a/firecrown/likelihood/gauss_family/statistic/source/source.py +++ b/firecrown/likelihood/gauss_family/statistic/source/source.py @@ -1,10 +1,14 @@ -"""Abstract base classes for GaussianFamily statistics. - +"""Abstract base classes for TwoPoint Statistics sources. """ from __future__ import annotations -from typing import Optional, Sequence, final +from typing import Optional, List, Sequence, final, TypeVar, Generic from abc import abstractmethod +from dataclasses import dataclass, replace +import numpy as np +import numpy.typing as npt +from scipy.interpolate import Akima1DInterpolator + import sacc import pyccl @@ -12,7 +16,8 @@ from .....modeling_tools import ModelingTools from .....parameters import ParamsMap -from .....updatable import Updatable +from ..... import parameters +from .....updatable import Updatable, UpdatableCollection class SourceSystematic(Updatable): @@ -159,3 +164,117 @@ def has_pt(self) -> bool: def has_hm(self) -> bool: """Return True if we have a halo_profile, and False if not.""" return self.halo_profile is not None + + +# Sources of galaxy distributions + + +@dataclass(frozen=True) +class SourceGalaxyArgs: + """Class for galaxy based sources arguments.""" + + z: npt.NDArray[np.float64] + dndz: npt.NDArray[np.float64] + + +_SourceGalaxyArgsT = TypeVar("_SourceGalaxyArgsT", bound=SourceGalaxyArgs) + + +class SourceGalaxySystematic(SourceSystematic, Generic[_SourceGalaxyArgsT]): + """Abstract base class for all galaxy based source systematics.""" + + @abstractmethod + def apply( + self, tools: ModelingTools, tracer_arg: _SourceGalaxyArgsT + ) -> _SourceGalaxyArgsT: + """Apply method to include systematics in the tracer_arg.""" + + +_SourceGalaxySystematicT = TypeVar( + "_SourceGalaxySystematicT", bound=SourceGalaxySystematic +) + + +class SourceGalaxyPhotoZShift( + SourceGalaxySystematic[_SourceGalaxyArgsT], Generic[_SourceGalaxyArgsT] +): + """A photo-z shift bias. + + This systematic shifts the photo-z distribution by some amount `delta_z`. + + The following parameters are special Updatable parameters, which means that + they can be updated by the sampler, sacc_tracer is going to be used as a + prefix for the parameters: + + :ivar delta_z: the photo-z shift. + """ + + def __init__(self, sacc_tracer: str): + """Create a PhotoZShift object, using the specified tracer name. + + :param sacc_tracer: the name of the tracer in the SACC file. This is used + as a prefix for its parameters. + """ + super().__init__(parameter_prefix=sacc_tracer) + + self.delta_z = parameters.create() + + def apply(self, tools: ModelingTools, tracer_arg: _SourceGalaxyArgsT): + """Apply a shift to the photo-z distribution of a source.""" + + dndz_interp = Akima1DInterpolator(tracer_arg.z, tracer_arg.dndz) + + dndz = dndz_interp(tracer_arg.z - self.delta_z, extrapolate=False) + dndz[np.isnan(dndz)] = 0.0 + + return replace( + tracer_arg, + dndz=dndz, + ) + + +class SourceGalaxy(Source, Generic[_SourceGalaxyArgsT]): + """Source class for galaxy based sources.""" + + def __init__( + self, + *, + sacc_tracer: str, + systematics: Optional[List[SourceGalaxySystematic]] = None, + ): + """Initialize the SourceGalaxy object. + + :param sacc_tracer: the name of the tracer in the SACC file. This is used + as a prefix for its parameters. + + """ + super().__init__(sacc_tracer) + + self.sacc_tracer = sacc_tracer + self.current_tracer_args: Optional[_SourceGalaxyArgsT] = None + self.systematics: UpdatableCollection = UpdatableCollection(systematics) + self.tracer_args: _SourceGalaxyArgsT + + def _read(self, sacc_data: sacc.Sacc): + """Read the galaxy redshift distribution model from a sacc file. + All derived classes must call this method in their own `_read` method + after they have read their own data and initialized their tracer_args.""" + + tracer = sacc_data.get_tracer(self.sacc_tracer) + + z = getattr(tracer, "z").copy().flatten() + nz = getattr(tracer, "nz").copy().flatten() + indices = np.argsort(z) + z = z[indices] + nz = nz[indices] + + if self.tracer_args is None: + raise RuntimeError( + "Must initialize tracer_args before calling _read on SourceGalaxy" + ) + + self.tracer_args = replace( + self.tracer_args, + z=z, + dndz=nz, + ) diff --git a/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py b/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py index 305bbc7ad..40a795671 100644 --- a/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py +++ b/firecrown/likelihood/gauss_family/statistic/source/weak_lensing.py @@ -3,7 +3,7 @@ """ from __future__ import annotations -from typing import List, Tuple, Optional, final +from typing import List, Tuple, Optional, Union, final from dataclasses import dataclass, replace from abc import abstractmethod @@ -12,26 +12,28 @@ import pyccl import pyccl.nl_pt import sacc -from scipy.interpolate import Akima1DInterpolator -from .source import Source, Tracer, SourceSystematic +from .source import ( + SourceGalaxy, + Tracer, + SourceGalaxyArgs, + SourceGalaxySystematic, + SourceGalaxyPhotoZShift, +) from ..... import parameters from .....parameters import ( ParamsMap, ) from .....modeling_tools import ModelingTools -from .....updatable import UpdatableCollection __all__ = ["WeakLensing"] @dataclass(frozen=True) -class WeakLensingArgs: +class WeakLensingArgs(SourceGalaxyArgs): """Class for weak lensing tracer builder argument.""" scale: float - z: npt.NDArray[np.float64] - dndz: npt.NDArray[np.float64] ia_bias: Optional[Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]] has_pt: bool = False @@ -42,7 +44,7 @@ class WeakLensingArgs: ia_pt_c_2: Optional[Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]] = None -class WeakLensingSystematic(SourceSystematic): +class WeakLensingSystematic(SourceGalaxySystematic[WeakLensingArgs]): """Abstract base class for all weak lensing systematics.""" @abstractmethod @@ -52,6 +54,10 @@ def apply( """Apply method to include systematics in the tracer_arg.""" +class PhotoZShift(SourceGalaxyPhotoZShift[WeakLensingArgs]): + """Photo-z shift systematic.""" + + class MultiplicativeShearBias(WeakLensingSystematic): """Multiplicative shear bias systematic. @@ -196,54 +202,17 @@ def apply( ) -class PhotoZShift(WeakLensingSystematic): - """A photo-z shift bias. - - This systematic shifts the photo-z distribution by some amount `delta_z`. - - The following parameters are special Updatable parameters, which means that - they can be updated by the sampler, sacc_tracer is going to be used as a - prefix for the parameters: - - :ivar delta_z: the photo-z shift. - """ - - def __init__(self, sacc_tracer: str): - """Create a PhotoZShift object, using the specified tracer name. - - :param sacc_tracer: the name of the tracer in the SACC file. This is used - as a prefix for its parameters. - """ - super().__init__(parameter_prefix=sacc_tracer) - - self.delta_z = parameters.create() - - def apply(self, tools: ModelingTools, tracer_arg: WeakLensingArgs): - """Apply a shift to the photo-z distribution of a source.""" - - dndz_interp = Akima1DInterpolator(tracer_arg.z, tracer_arg.dndz) - - dndz = dndz_interp(tracer_arg.z - self.delta_z, extrapolate=False) - dndz[np.isnan(dndz)] = 0.0 - - return replace( - tracer_arg, - dndz=dndz, - ) - - -class WeakLensing(Source): +class WeakLensing(SourceGalaxy[WeakLensingArgs]): """Source class for weak lensing.""" - systematics: UpdatableCollection - tracer_args: WeakLensingArgs - def __init__( self, *, sacc_tracer: str, scale: float = 1.0, - systematics: Optional[List[WeakLensingSystematic]] = None, + systematics: Optional[ + List[Union[WeakLensingSystematic, SourceGalaxySystematic[WeakLensingArgs]]] + ] = None, ): """Initialize the WeakLensing object. @@ -255,12 +224,12 @@ def __init__( this source. """ - super().__init__(sacc_tracer) + super().__init__(sacc_tracer=sacc_tracer, systematics=systematics) self.sacc_tracer = sacc_tracer self.scale = scale self.current_tracer_args: Optional[WeakLensingArgs] = None - self.systematics = UpdatableCollection(systematics) + self.tracer_args: WeakLensingArgs @final def _update_source(self, params: ParamsMap): @@ -275,15 +244,11 @@ def _read(self, sacc_data: sacc.Sacc) -> None: This sets self.tracer_args, based on the data in `sacc_data` associated with this object's `sacc_tracer` name. """ - tracer = sacc_data.get_tracer(self.sacc_tracer) - - z = getattr(tracer, "z").copy().flatten() - nz = getattr(tracer, "nz").copy().flatten() - indices = np.argsort(z) - z = z[indices] - nz = nz[indices] + self.tracer_args = WeakLensingArgs( + scale=self.scale, z=np.array([]), dndz=np.array([]), ia_bias=None + ) - self.tracer_args = WeakLensingArgs(scale=self.scale, z=z, dndz=nz, ia_bias=None) + super()._read(sacc_data) def create_tracers(self, tools: ModelingTools): """