Skip to content

Commit

Permalink
Dataclass generator serialization (#426)
Browse files Browse the repository at this point in the history
* Organizing warnings and filters for tests.
* Add missing default values.
* Renaming MeasuredType to Measurement.
* Adding pydantic and working on serialization.
* InferredZDist tutorial draft.
* Improving tutorials cross ref.
* New inferred_zdist_serialization tutorial.
* Adding support for serializing factories.
* Updated global factory use.
* Updated tutorial to show factory serialization.
* Remove obsolete parameter to fix ComoSIS complaint
* Lock scipy version because update breaks CCL
* Change spelling: 1d -> 2D
* Introduce LogLinearElls generator
* Tweaks to tutorials
* Improve and test error messages
* Add type annotation
* Add test of global TattAlignmentSystematic
* Test the failure modes of base_model_from_yaml
---------
Co-authored-by: Marc Paterno <paterno@fnal.gov>
  • Loading branch information
vitenti authored Jun 26, 2024
1 parent 21b67d6 commit 71504b5
Show file tree
Hide file tree
Showing 34 changed files with 1,724 additions and 537 deletions.
1 change: 1 addition & 0 deletions docs/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies:
- cobaya
- pygobject-stubs
- pyccl>=2.8.0
- pydantic
- pytest
- quarto
- sacc>=0.11
Expand Down
3 changes: 2 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ dependencies:
- portalocker
- pybobyqa
- pyccl >= 2.8.0
- pydantic
- pylint
- pytest
- pytest-cov
- pyyaml
- requests
- sacc >= 0.11
- scipy
- scipy < 1.14.0
- sphinx=7.1.2
- types-pyyaml
- urllib3
1 change: 0 additions & 1 deletion examples/des_y1_3x2pt/des_y1_3x2pt_PT.ini
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ feedback = 0
zmin = 0.0
zmax = 4.0
nz = 100
kmin = 1e-4
kmax = 50.0
nk = 1000

Expand Down
4 changes: 0 additions & 4 deletions firecrown/connector/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def __init__(self, *, require_nonlinear_pk: bool = False):

def get_params_names(self) -> list[str]:
"""Return the names of the expected cosmological parameters for this mapping."""
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
"This method is implementation specific and should only be "
"implemented on the appropriated subclasses. This method"
Expand All @@ -78,7 +77,6 @@ def get_params_names(self) -> list[str]:
def transform_k_h_to_k(self, k_h):
"""Transform the given k_h (k over h) to k."""
assert k_h is not None # use assertion to silence pylint warning
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
"This method is implementation specific and should only be "
"implemented on the appropriated subclasses. This method"
Expand All @@ -89,7 +87,6 @@ def transform_k_h_to_k(self, k_h):
def transform_p_k_h3_to_p_k(self, p_k_h3):
r"""Transform the given :math:`p_k h^3 \to p_k`."""
assert p_k_h3 is not None # use assertion to silence pylint warning
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
"This method is implementation specific and should only be "
"implemented on the appropriated subclasses. This method"
Expand All @@ -100,7 +97,6 @@ def transform_p_k_h3_to_p_k(self, p_k_h3):
def transform_h_to_h_over_h0(self, h):
"""Transform distances h to :math:`h/h_0`."""
assert h is not None # use assertion to silence pylint warning
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
"This method is implementation specific and should only be "
"implemented on the appropriated subclasses. This method"
Expand Down
213 changes: 206 additions & 7 deletions firecrown/generators/inferred_galaxy_zdist.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,40 @@
"""Generation of inferred galaxy redshift distributions."""

from typing import TypedDict, Annotated, Any
from itertools import pairwise

from pydantic import BaseModel, ConfigDict, Field, field_serializer, BeforeValidator

import numpy as np
import numpy.typing as npt
from scipy.special import gamma, erf, erfc
from scipy.integrate import quad

from numcosmo_py import Ncm

from firecrown.metadata.two_point import InferredGalaxyZDist, MeasuredType
from firecrown.metadata.two_point import (
InferredGalaxyZDist,
Measurement,
Galaxies,
CMB,
Clusters,
ALL_MEASUREMENT_TYPES,
make_measurement_dict,
)

BinsType = TypedDict("BinsType", {"edges": npt.NDArray, "sigma_z": float})

Y1_ALPHA = 0.94
Y1_BETA = 2.0
Y1_Z0 = 0.26
Y1_LENS_BINS = {"edges": np.linspace(0.2, 1.2, 5 + 1), "sigma_z": 0.03}
Y1_SOURCE_BINS = {"edges": np.linspace(0.2, 1.2, 5 + 1), "sigma_z": 0.05}
Y1_LENS_BINS: BinsType = {"edges": np.linspace(0.2, 1.2, 5 + 1), "sigma_z": 0.03}
Y1_SOURCE_BINS: BinsType = {"edges": np.linspace(0.2, 1.2, 5 + 1), "sigma_z": 0.05}

Y10_ALPHA = 0.90
Y10_BETA = 2.0
Y10_Z0 = 0.28
Y10_LENS_BINS = {"edges": np.linspace(0.2, 1.2, 10 + 1), "sigma_z": 0.03}
Y10_SOURCE_BINS = {"edges": np.linspace(0.2, 1.2, 10 + 1), "sigma_z": 0.05}
Y10_LENS_BINS: BinsType = {"edges": np.linspace(0.2, 1.2, 10 + 1), "sigma_z": 0.03}
Y10_SOURCE_BINS: BinsType = {"edges": np.linspace(0.2, 1.2, 10 + 1), "sigma_z": 0.05}


class ZDistLSSTSRD:
Expand Down Expand Up @@ -128,7 +143,7 @@ def binned_distribution(
sigma_z: float,
z: npt.NDArray,
name: str,
measured_type: MeasuredType,
measurement: Measurement,
use_autoknot: bool = False,
autoknots_reltol: float = 1.0e-4,
autoknots_abstol: float = 1.0e-15,
Expand Down Expand Up @@ -169,5 +184,189 @@ def _P(z, _):
)

return InferredGalaxyZDist(
bin_name=name, z=z_knots, dndz=dndz, measured_type=measured_type
bin_name=name, z=z_knots, dndz=dndz, measurement=measurement
)


class LinearGrid1D(BaseModel):
"""A 1D linear grid."""

model_config = ConfigDict(extra="forbid", frozen=True)

start: float
end: float
num: int

def generate(self) -> npt.NDArray:
"""Generate the 1D linear grid."""
return np.linspace(self.start, self.end, self.num)


class RawGrid1D(BaseModel):
"""A 1D grid."""

model_config = ConfigDict(extra="forbid", frozen=True)

values: list[float]

def generate(self) -> npt.NDArray:
"""Generate the 1D grid."""
return np.array(self.values)


Grid1D = LinearGrid1D | RawGrid1D


def make_measurement(value: Measurement | dict[str, Any]) -> Measurement:
"""Create a Measurement object from a dictionary."""
if isinstance(value, ALL_MEASUREMENT_TYPES):
return value

if not isinstance(value, dict):
raise ValueError(f"Invalid Measurement: {value} is not a dictionary")

if "subject" not in value:
raise ValueError("Invalid Measurement: dictionary does not contain 'subject'")

subject = value["subject"]

match subject:
case "Galaxies":
return Galaxies[value["property"]]
case "CMB":
return CMB[value["property"]]
case "Clusters":
return Clusters[value["property"]]
case _:
raise ValueError(
f"Invalid Measurement: subject: '{subject}' is not recognized"
)


class ZDistLSSTSRDBin(BaseModel):
"""LSST Inferred galaxy redshift distributions in bins."""

model_config = ConfigDict(extra="forbid")

zpl: float
zpu: float
sigma_z: float
z: Annotated[Grid1D, Field(union_mode="left_to_right")]
bin_name: str
measurement: Annotated[Measurement, BeforeValidator(make_measurement)]
use_autoknot: bool = False
autoknots_reltol: float = 1.0e-4
autoknots_abstol: float = 1.0e-15

@field_serializer("measurement")
@classmethod
def serialize_measurement(cls, value: Measurement) -> dict:
"""Serialize the Measurement."""
return make_measurement_dict(value)

def generate(self, zdist: ZDistLSSTSRD) -> InferredGalaxyZDist:
"""Generate the inferred galaxy redshift distribution in bins."""
return zdist.binned_distribution(
zpl=self.zpl,
zpu=self.zpu,
sigma_z=self.sigma_z,
z=self.z.generate(),
name=self.bin_name,
measurement=self.measurement,
use_autoknot=self.use_autoknot,
autoknots_reltol=self.autoknots_reltol,
autoknots_abstol=self.autoknots_abstol,
)


class ZDistLSSTSRDBinCollection(BaseModel):
"""LSST Inferred galaxy redshift distributions in bins."""

model_config = ConfigDict(extra="forbid", frozen=True)

alpha: float
beta: float
z0: float
bins: list[ZDistLSSTSRDBin]

def generate(self) -> list[InferredGalaxyZDist]:
"""Generate the inferred galaxy redshift distributions in bins."""
zdist = ZDistLSSTSRD(alpha=self.alpha, beta=self.beta, z0=self.z0)
return [bin.generate(zdist) for bin in self.bins]


LSST_Y1_LENS_BIN_COLLECTION = ZDistLSSTSRDBinCollection(
alpha=Y1_ALPHA,
beta=Y1_BETA,
z0=Y1_Z0,
bins=[
ZDistLSSTSRDBin(
zpl=zpl,
zpu=zpu,
sigma_z=Y1_LENS_BINS["sigma_z"],
z=RawGrid1D(values=[0.0, 3.0]),
bin_name=f"lens_{zpl:.1f}_{zpu:.1f}_y1",
measurement=Galaxies.COUNTS,
use_autoknot=True,
autoknots_reltol=1.0e-5,
)
for zpl, zpu in pairwise(Y1_LENS_BINS["edges"])
],
)

LSST_Y1_SOURCE_BIN_COLLECTION = ZDistLSSTSRDBinCollection(
alpha=Y1_ALPHA,
beta=Y1_BETA,
z0=Y1_Z0,
bins=[
ZDistLSSTSRDBin(
zpl=zpl,
zpu=zpu,
sigma_z=Y1_SOURCE_BINS["sigma_z"],
z=RawGrid1D(values=[0.0, 3.0]),
bin_name=f"source_{zpl:.1f}_{zpu:.1f}_y1",
measurement=Galaxies.SHEAR_E,
use_autoknot=True,
autoknots_reltol=1.0e-5,
)
for zpl, zpu in pairwise(Y1_SOURCE_BINS["edges"])
],
)

LSST_Y10_LENS_BIN_COLLECTION = ZDistLSSTSRDBinCollection(
alpha=Y10_ALPHA,
beta=Y10_BETA,
z0=Y10_Z0,
bins=[
ZDistLSSTSRDBin(
zpl=zpl,
zpu=zpu,
sigma_z=Y10_LENS_BINS["sigma_z"],
z=RawGrid1D(values=[0.0, 3.0]),
bin_name=f"lens_{zpl:.1f}_{zpu:.1f}_y10",
measurement=Galaxies.COUNTS,
use_autoknot=True,
autoknots_reltol=1.0e-5,
)
for zpl, zpu in pairwise(Y10_LENS_BINS["edges"])
],
)

LSSST_Y10_SOURCE_BIN_COLLECTION = ZDistLSSTSRDBinCollection(
alpha=Y10_ALPHA,
beta=Y10_BETA,
z0=Y10_Z0,
bins=[
ZDistLSSTSRDBin(
zpl=zpl,
zpu=zpu,
sigma_z=Y10_SOURCE_BINS["sigma_z"],
z=RawGrid1D(values=[0.0, 3.0]),
bin_name=f"source_{zpl:.1f}_{zpu:.1f}_y10",
measurement=Galaxies.SHEAR_E,
use_autoknot=True,
autoknots_reltol=1.0e-5,
)
for zpl, zpu in pairwise(Y10_SOURCE_BINS["edges"])
],
)
51 changes: 51 additions & 0 deletions firecrown/generators/two_point.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Generator support for TwoPoint statistics."""

from typing import Annotated
from pydantic import BaseModel, Field, model_validator

import numpy as np
import numpy.typing as npt


class LogLinearElls(BaseModel):
"""Generator for log-linear integral ell values.
Not all ell values will be generated. The result will contain each integral
value from min to mid. Starting from mid, and going up to max, there will be
n_log logarithmically spaced values.
Note that midpoint must be strictly greater than minimum, and strictly less
than maximum. n_log must be positive.
"""

minimum: Annotated[int, Field(ge=0)]
midpoint: Annotated[int, Field(ge=0)]
maximum: Annotated[int, Field(ge=0)]
n_log: Annotated[int, Field(ge=1)]

@model_validator(mode="after")
def require_increasing(self) -> "LogLinearElls":
"""Validate the ell values."""
assert self.minimum < self.midpoint
assert self.midpoint < self.maximum
return self

def generate(self) -> npt.NDArray[np.int64]:
"""Generate the log-linear ell values.
The result will contain each integral value from min to mid. Starting
from mid, and going up to max, there will be n_log logarithmically
spaced values.
"""
minimum, midpoint, maximum, n_log = (
self.minimum,
self.midpoint,
self.maximum,
self.n_log,
)
lower_range = np.linspace(minimum, midpoint - 1, midpoint - minimum)
upper_range = np.logspace(np.log10(midpoint), np.log10(maximum), n_log)
concatenated = np.concatenate((lower_range, upper_range))
# Round the results to the nearest integer values.
# N.B. the dtype of the result is np.dtype[float64]
return np.unique(np.around(concatenated)).astype(np.int64)
1 change: 0 additions & 1 deletion firecrown/likelihood/gaussfamily.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,6 @@ def compute(
self, tools: ModelingTools
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
"""Calculate and return both the data and theory vectors."""
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
"The use of the `compute` method on Statistic is deprecated."
"The Statistic objects should implement `get_data` and "
Expand Down
1 change: 0 additions & 1 deletion firecrown/likelihood/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,6 @@ def load_likelihood_from_module_type(
f"{module.__file__} does not define "
f"a `build_likelihood` factory function."
)
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
"The use of a likelihood variable in Firecrown's initialization "
"module is deprecated. Any parameters passed to the likelihood "
Expand Down
Loading

0 comments on commit 71504b5

Please sign in to comment.