Skip to content

Commit

Permalink
Refactor twopoint (#451)
Browse files Browse the repository at this point in the history
* Add type annotations
* Remove incorrect documentation
* Remove extraneous import
* Document and test _generate_ell_or_theta
* Rename _generate_ell_or_theta to generate_bin_centers
* Move ELL_FOR_XI_DEFAULTS to generators
* Rename _ell_for_xi to log_linear_ells
* Move log_linear_ells to generators
* Move generate_bin_centers to generators
* Temporary tweaks to fix CI failures from pylint 3.3
* Move _cached_angular_distribution and make_log_interpolator to utils
* Move calculate_ells_for_interpolation to generators
* Move EllOrThetaConfig to generators
* Move generate_{ells_cells,reals} to generators
* Move apply_{ells,thetas}_min_max to generators
* Move use_source_factory and use_source_factory_metadata_index to source_factories
* Start of TwoPointTheory
* Put sources in TwoPointTheory; make it Updatable
* Move ell_for_xi_config to TwoPointTheory
* Move ell_or_theta_config into TwoPointTheory
* Move ell_or_theta_{min,max} into TwoPointTheory
* Move window to TwoPointTheory
* Remove unused TwoPoint.theory_vector
* Move sacc_tracers into TwoPointTheory
* Remove needless implementation of _update
---------
Co-authored-by: Sandro Dias Pinto Vitenti <vitenti@uel.br>
  • Loading branch information
marcpaterno authored Oct 4, 2024
1 parent 1ec56d9 commit a26cc7a
Show file tree
Hide file tree
Showing 7 changed files with 499 additions and 327 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies:
- pylint
- pytest
- pytest-cov
- python >= 3.10
- pyyaml
- requests
- sacc >= 0.11
Expand Down
177 changes: 175 additions & 2 deletions firecrown/generators/two_point.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
"""Generator support for TwoPoint statistics."""

from typing import Annotated
from pydantic import BaseModel, Field, model_validator
from __future__ import annotations

import copy
from typing import Annotated, TypedDict

from pydantic import BaseModel, Field, model_validator
import numpy as np
import numpy.typing as npt


ELL_FOR_XI_DEFAULTS = {"minimum": 2, "midpoint": 50, "maximum": 60_000, "n_log": 200}


class LogLinearElls(BaseModel):
"""Generator for log-linear integral ell values.
Expand Down Expand Up @@ -36,6 +42,11 @@ def generate(self) -> npt.NDArray[np.int64]:
The result will contain each integral value from min to mid. Starting
from mid, and going up to max, there will be n_log logarithmically
spaced values.
:param minimum: The low edge of the first bin.
:param midpoint: The high edge of the last in the linear range.
:param maximum: The high edge of the last bin.
:param n_log: The number of bins in the log section of the range.
"""
minimum, midpoint, maximum, n_log = (
self.minimum,
Expand All @@ -49,3 +60,165 @@ def generate(self) -> npt.NDArray[np.int64]:
# Round the results to the nearest integer values.
# N.B. the dtype of the result is np.dtype[float64]
return np.unique(np.around(concatenated)).astype(np.int64)


def log_linear_ells(
*, minimum: int, midpoint: int, maximum: int, n_log: int
) -> npt.NDArray[np.int64]:
"""Create an array of ells to sample the power spectrum.
This is used for for real-space predictions. The result will contain
each integral value from min to mid. Starting from mid, and going up
to max, there will be n_log logarithmically spaced values.
All values are rounded to the nearest integer.
:param minimum: The low edge of the first bin.
:param midpoint: The high edge of the last in the linear range.
:param maximum: The high edge of the last bin.
:param n_log: The number of bins in the log section of the range.
"""
return LogLinearElls(
minimum=minimum, midpoint=midpoint, maximum=maximum, n_log=n_log
).generate()


def generate_bin_centers(
*, minimum: float, maximum: float, n: int, binning: str = "log"
) -> npt.NDArray[np.float64]:
"""Return the centers of bins that span the range from minimum to maximum.
If binning is 'log', this will generate logarithmically spaced bins; if
binning is 'lin', this will generate linearly spaced bins.
:param minimum: The low edge of the first bin.
:param maximum: The high edge of the last bin.
:param n: The number of bins.
:param binning: Either 'log' or 'lin'.
:return: The centers of the bins.
"""
match binning:
case "log":
edges = np.logspace(np.log10(minimum), np.log10(maximum), n + 1)
return np.sqrt(edges[1:] * edges[:-1])
case "lin":
edges = np.linspace(minimum, maximum, n + 1)
return (edges[1:] + edges[:-1]) / 2.0
case _:
raise ValueError(f"Unrecognized binning: {binning}")


def calculate_ells_for_interpolation(
min_ell: int, max_ell: int
) -> npt.NDArray[np.int64]:
"""See log_linear_ells.
This method mixes together:
1. the default parameters in ELL_FOR_XI_DEFAULTS
2. the first and last values in w.
and then calls log_linear_ells with those arguments, returning whatever it
returns.
"""
ell_config = copy.deepcopy(ELL_FOR_XI_DEFAULTS)
ell_config["maximum"] = max_ell
ell_config["minimum"] = max(ell_config["minimum"], min_ell)
return log_linear_ells(**ell_config)


class EllOrThetaConfig(TypedDict):
"""A dictionary of options for generating the ell or theta.
This dictionary contains the minimum, maximum and number of
bins to generate the ell or theta values at which to compute the statistics.
:param minimum: The start of the binning.
:param maximum: The end of the binning.
:param n: The number of bins.
:param binning: Pass 'log' to get logarithmic spaced bins and 'lin' to get linearly
spaced bins. Default is 'log'.
"""

minimum: float
maximum: float
n: int
binning: str


def generate_ells_cells(ell_config: EllOrThetaConfig):
"""Generate ells or theta values from the configuration dictionary.
:param ell_config: the configuration parameters.
:return: ells and Cells
"""
ells = generate_bin_centers(**ell_config)
Cells = np.zeros_like(ells)

return ells, Cells


def generate_reals(theta_config: EllOrThetaConfig):
"""Generate theta and xi values from the configuration dictionary.
:param ell_config: the configuration parameters.
:return: ells and Cells
"""
thetas = generate_bin_centers(**theta_config)
xis = np.zeros_like(thetas)

return thetas, xis


def apply_ells_min_max(
ells: npt.NDArray[np.int64],
Cells: npt.NDArray[np.float64],
indices: None | npt.NDArray[np.int64],
ell_min: None | int,
ell_max: None | int,
) -> tuple[
npt.NDArray[np.int64], npt.NDArray[np.float64], None | npt.NDArray[np.int64]
]:
"""Apply the minimum and maximum ell values to the ells and Cells."""
if ell_min is not None:
locations = np.where(ells >= ell_min)
ells = ells[locations]
Cells = Cells[locations]
if indices is not None:
indices = indices[locations]

if ell_max is not None:
locations = np.where(ells <= ell_max)
ells = ells[locations]
Cells = Cells[locations]
if indices is not None:
indices = indices[locations]

return ells, Cells, indices


def apply_theta_min_max(
thetas: npt.NDArray[np.float64],
xis: npt.NDArray[np.float64],
indices: None | npt.NDArray[np.int64],
theta_min: None | float,
theta_max: None | float,
) -> tuple[
npt.NDArray[np.float64], npt.NDArray[np.float64], None | npt.NDArray[np.int64]
]:
"""Apply the minimum and maximum theta values to the thetas and xis."""
if theta_min is not None:
locations = np.where(thetas >= theta_min)
thetas = thetas[locations]
xis = xis[locations]
if indices is not None:
indices = indices[locations]

if theta_max is not None:
locations = np.where(thetas <= theta_max)
thetas = thetas[locations]
xis = xis[locations]
if indices is not None:
indices = indices[locations]

return thetas, xis, indices
61 changes: 61 additions & 0 deletions firecrown/likelihood/source_factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Factory functions for creating sources."""

from firecrown.likelihood.number_counts import NumberCountsFactory, NumberCounts
from firecrown.likelihood.weak_lensing import WeakLensingFactory, WeakLensing
from firecrown.metadata_types import InferredGalaxyZDist, Measurement, Galaxies


def use_source_factory(
inferred_galaxy_zdist: InferredGalaxyZDist,
measurement: Measurement,
wl_factory: WeakLensingFactory | None = None,
nc_factory: NumberCountsFactory | None = None,
) -> WeakLensing | NumberCounts:
"""Apply the factory to the inferred galaxy redshift distribution."""
source: WeakLensing | NumberCounts
if measurement not in inferred_galaxy_zdist.measurements:
raise ValueError(
f"Measurement {measurement} not found in inferred galaxy redshift "
f"distribution {inferred_galaxy_zdist.bin_name}!"
)

match measurement:
case Galaxies.COUNTS:
assert nc_factory is not None
source = nc_factory.create(inferred_galaxy_zdist)
case (
Galaxies.SHEAR_E
| Galaxies.SHEAR_T
| Galaxies.SHEAR_MINUS
| Galaxies.SHEAR_PLUS
):
assert wl_factory is not None
source = wl_factory.create(inferred_galaxy_zdist)
case _:
raise ValueError(f"Measurement {measurement} not supported!")
return source


def use_source_factory_metadata_index(
sacc_tracer: str,
measurement: Measurement,
wl_factory: WeakLensingFactory | None = None,
nc_factory: NumberCountsFactory | None = None,
) -> WeakLensing | NumberCounts:
"""Apply the factory to create a source from metadata only."""
source: WeakLensing | NumberCounts
match measurement:
case Galaxies.COUNTS:
assert nc_factory is not None
source = nc_factory.create_from_metadata_only(sacc_tracer)
case (
Galaxies.SHEAR_E
| Galaxies.SHEAR_T
| Galaxies.SHEAR_MINUS
| Galaxies.SHEAR_PLUS
):
assert wl_factory is not None
source = wl_factory.create_from_metadata_only(sacc_tracer)
case _:
raise ValueError(f"Measurement {measurement} not supported!")
return source
Loading

0 comments on commit a26cc7a

Please sign in to comment.