Skip to content

Commit

Permalink
Merge branch 'master' into support-m2-mac
Browse files Browse the repository at this point in the history
  • Loading branch information
marcpaterno authored Oct 20, 2023
2 parents d0a853e + f849d18 commit a8d9eab
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 51 deletions.
24 changes: 12 additions & 12 deletions firecrown/likelihood/gauss_family/statistic/source/number_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ def __init__(self, sacc_tracer: str):
"""
super().__init__(parameter_prefix=sacc_tracer)

self.alphaz = parameters.create()
self.alphag = parameters.create()
self.z_piv = parameters.create()
self.alphaz = parameters.register_new_updatable_parameter()
self.alphag = parameters.register_new_updatable_parameter()
self.z_piv = parameters.register_new_updatable_parameter()

def apply(
self, tools: ModelingTools, tracer_arg: NumberCountsArgs
Expand Down Expand Up @@ -143,8 +143,8 @@ def __init__(self, sacc_tracer: str):
"""
super().__init__(parameter_prefix=sacc_tracer)

self.b_2 = parameters.create()
self.b_s = parameters.create()
self.b_2 = parameters.register_new_updatable_parameter()
self.b_s = parameters.register_new_updatable_parameter()

def apply(
self, tools: ModelingTools, tracer_arg: NumberCountsArgs
Expand Down Expand Up @@ -186,11 +186,11 @@ def __init__(self, sacc_tracer: str):
"""
super().__init__(parameter_prefix=sacc_tracer)

self.r_lim = parameters.create()
self.sig_c = parameters.create()
self.eta = parameters.create()
self.z_c = parameters.create()
self.z_m = parameters.create()
self.r_lim = parameters.register_new_updatable_parameter()
self.sig_c = parameters.register_new_updatable_parameter()
self.eta = parameters.register_new_updatable_parameter()
self.z_c = parameters.register_new_updatable_parameter()
self.z_m = parameters.register_new_updatable_parameter()

def apply(
self, tools: ModelingTools, tracer_arg: NumberCountsArgs
Expand Down Expand Up @@ -245,7 +245,7 @@ def __init__(self, sacc_tracer: str):
"""
super().__init__(parameter_prefix=sacc_tracer)

self.mag_bias = parameters.create()
self.mag_bias = parameters.register_new_updatable_parameter()

def apply(
self, tools: ModelingTools, tracer_arg: NumberCountsArgs
Expand Down Expand Up @@ -291,7 +291,7 @@ def __init__(
self.has_rsd = has_rsd
self.derived_scale = derived_scale

self.bias = parameters.create()
self.bias = parameters.register_new_updatable_parameter()
self.systematics = UpdatableCollection(systematics)
self.scale = scale
self.current_tracer_args: Optional[NumberCountsArgs] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def __init__(self, sacc_tracer: str):
"""
super().__init__(parameter_prefix=sacc_tracer)

self.delta_z = parameters.create()
self.delta_z = parameters.register_new_updatable_parameter()

def apply(self, tools: ModelingTools, tracer_arg: _SourceGalaxyArgsT):
"""Apply a shift to the photo-z distribution of a source."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, sacc_tracer: str) -> None:
"""
super().__init__(parameter_prefix=sacc_tracer)

self.mult_bias = parameters.create()
self.mult_bias = parameters.register_new_updatable_parameter()

def apply(
self, tools: ModelingTools, tracer_arg: WeakLensingArgs
Expand Down Expand Up @@ -124,10 +124,10 @@ def __init__(self, sacc_tracer: Optional[str] = None, alphag=1.0):
"""
super().__init__(parameter_prefix=sacc_tracer)

self.ia_bias = parameters.create()
self.alphaz = parameters.create()
self.alphag = parameters.create(alphag)
self.z_piv = parameters.create()
self.ia_bias = parameters.register_new_updatable_parameter()
self.alphaz = parameters.register_new_updatable_parameter()
self.alphag = parameters.register_new_updatable_parameter(alphag)
self.z_piv = parameters.register_new_updatable_parameter()

def apply(
self, tools: ModelingTools, tracer_arg: WeakLensingArgs
Expand Down Expand Up @@ -172,9 +172,9 @@ def __init__(self, sacc_tracer: Optional[str] = None):
as a prefix for its parameters.
"""
super().__init__(parameter_prefix=sacc_tracer)
self.ia_a_1 = parameters.create()
self.ia_a_2 = parameters.create()
self.ia_a_d = parameters.create()
self.ia_a_1 = parameters.register_new_updatable_parameter()
self.ia_a_2 = parameters.register_new_updatable_parameter()
self.ia_a_d = parameters.register_new_updatable_parameter()

def apply(
self, tools: ModelingTools, tracer_arg: WeakLensingArgs
Expand Down
2 changes: 1 addition & 1 deletion firecrown/likelihood/gauss_family/statistic/statistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def __init__(self) -> None:
# Data and theory will both be of length self.count
self.count = 3
self.data_vector: Optional[DataVector] = None
self.mean = firecrown.parameters.create()
self.mean = firecrown.parameters.register_new_updatable_parameter()
self.computed_theory_vector = False

def read(self, sacc_data: sacc.Sacc):
Expand Down
2 changes: 1 addition & 1 deletion firecrown/likelihood/gauss_family/statistic/supernova.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, sacc_tracer) -> None:
self.sacc_tracer = sacc_tracer
self.data_vector: Optional[DataVector] = None
self.a: Optional[npt.NDArray[np.float64]] = None
self.M = parameters.create()
self.M = parameters.register_new_updatable_parameter()

def read(self, sacc_data: sacc.Sacc):
"""Read the data for this statistic from the SACC file."""
Expand Down
2 changes: 1 addition & 1 deletion firecrown/likelihood/gauss_family/student_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
nu: Optional[float],
):
super().__init__(statistics)
self.nu = parameters.create(nu)
self.nu = parameters.register_new_updatable_parameter(nu)

def compute_loglike(self, tools: ModelingTools):
"""Compute the log-likelihood.
Expand Down
12 changes: 6 additions & 6 deletions firecrown/models/cluster_mass_rich_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ def __init__(
self.logMu = logMu

# Updatable parameters
self.mu_p0 = parameters.create()
self.mu_p1 = parameters.create()
self.mu_p2 = parameters.create()
self.sigma_p0 = parameters.create()
self.sigma_p1 = parameters.create()
self.sigma_p2 = parameters.create()
self.mu_p0 = parameters.register_new_updatable_parameter()
self.mu_p1 = parameters.register_new_updatable_parameter()
self.mu_p2 = parameters.register_new_updatable_parameter()
self.sigma_p0 = parameters.register_new_updatable_parameter()
self.sigma_p1 = parameters.register_new_updatable_parameter()
self.sigma_p2 = parameters.register_new_updatable_parameter()

self.logM_obs_min = 0.0
self.logM_obs_max = np.inf
Expand Down
15 changes: 15 additions & 0 deletions firecrown/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations
from typing import Iterable, List, Dict, Set, Tuple, Optional, Iterator, Sequence
import warnings
from abc import ABC, abstractmethod


Expand Down Expand Up @@ -304,6 +305,20 @@ def get_value(self) -> float:
def create(value: Optional[float] = None):
"""Create a new parameter, either a SamplerParameter or an InternalParameter.
See register_new_updatable_parameter for details."""
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
"This function is named `create` and will be removed in a future version "
"due to its name being too generic."
"Use `register_new_updatable_parameter` instead.",
category=DeprecationWarning,
)
return register_new_updatable_parameter(value)


def register_new_updatable_parameter(value: Optional[float] = None):
"""Create a new parameter, either a SamplerParameter or an InternalParameter.
If `value` is `None`, the result will be a `SamplerParameter`; Firecrown
will expect this value to be supplied by the sampling framwork. If `value`
is a `float` quantity, then Firecrown will expect this parameter to *not*
Expand Down
2 changes: 1 addition & 1 deletion tests/likelihood/lkdir/lkmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, params: NamedParameters):
parameter_prefix value and creates a sampler parameter called "sampler_param0".
"""
super().__init__(parameter_prefix=params.get_string("parameter_prefix"))
self.sampler_param0 = parameters.create()
self.sampler_param0 = parameters.register_new_updatable_parameter()

def read(self, sacc_data: sacc.Sacc) -> None:
"""This class has nothing to read."""
Expand Down
28 changes: 23 additions & 5 deletions tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from firecrown.parameters import (
DerivedParameterScalar,
DerivedParameterCollection,
register_new_updatable_parameter,
create,
InternalParameter,
SamplerParameter,
Expand All @@ -16,23 +17,40 @@
def test_create_with_no_arg():
"""Calling parameters.create() with no argument should return an
SamplerParameter"""
a_parameter = create()
assert isinstance(a_parameter, SamplerParameter)
with pytest.deprecated_call():
a_parameter = create()
assert isinstance(a_parameter, SamplerParameter)


def test_create_with_float_arg():
"""Calling parameters.create() with a float argument should return a
InternalParameter ."""
a_parameter = create(1.5)
with pytest.deprecated_call():
a_parameter = create(1.5)
assert isinstance(a_parameter, InternalParameter)
assert a_parameter.value == 1.5


def test_register_new_updatable_parameter_with_no_arg():
"""Calling parameters.create() with no argument should return an
SamplerParameter"""
a_parameter = register_new_updatable_parameter()
assert isinstance(a_parameter, SamplerParameter)


def test_register_new_updatable_parameter_with_float_arg():
"""Calling parameters.create() with a float argument should return a
InternalParameter ."""
a_parameter = register_new_updatable_parameter(1.5)
assert isinstance(a_parameter, InternalParameter)
assert a_parameter.value == 1.5


def test_create_with_wrong_arg():
def test_register_new_updatable_parameter_with_wrong_arg():
"""Calling parameters.create() with an org that is neither float nor None should
raise a TypeError."""
with pytest.raises(TypeError):
_ = create("cow") # type: ignore
_ = register_new_updatable_parameter("cow") # type: ignore


def test_required_parameters_length():
Expand Down
44 changes: 29 additions & 15 deletions tests/test_updatable.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self):
"""Initialize object with defaulted value."""
super().__init__()

self.a = parameters.create()
self.a = parameters.register_new_updatable_parameter()


class SimpleUpdatable(Updatable):
Expand All @@ -32,8 +32,8 @@ def __init__(self):
"""Initialize object with defaulted values."""
super().__init__()

self.x = parameters.create()
self.y = parameters.create()
self.x = parameters.register_new_updatable_parameter()
self.y = parameters.register_new_updatable_parameter()


class UpdatableWithDerived(Updatable):
Expand All @@ -43,8 +43,8 @@ def __init__(self):
"""Initialize object with defaulted values."""
super().__init__()

self.A = parameters.create()
self.B = parameters.create()
self.A = parameters.register_new_updatable_parameter()
self.B = parameters.register_new_updatable_parameter()

def _get_derived_parameters(self) -> DerivedParameterCollection:
derived_scale = DerivedParameterScalar("Section", "Name", self.A + self.B)
Expand Down Expand Up @@ -132,7 +132,9 @@ def test_updatable_collection_insertion():

def test_set_sampler_parameter():
my_updatable = MinimalUpdatable()
my_updatable.set_sampler_parameter("the_meaning_of_life", parameters.create())
my_updatable.set_sampler_parameter(
"the_meaning_of_life", parameters.register_new_updatable_parameter()
)

assert hasattr(my_updatable, "the_meaning_of_life")
assert my_updatable.the_meaning_of_life is None
Expand All @@ -143,21 +145,27 @@ def test_set_sampler_parameter_rejects_internal_parameter():

with pytest.raises(TypeError):
my_updatable.set_sampler_parameter(
"the_meaning_of_life", parameters.create(42.0)
"the_meaning_of_life", parameters.register_new_updatable_parameter(42.0)
)


def test_set_sampler_parameter_rejects_duplicates():
my_updatable = MinimalUpdatable()
my_updatable.set_sampler_parameter("the_meaning_of_life", parameters.create())
my_updatable.set_sampler_parameter(
"the_meaning_of_life", parameters.register_new_updatable_parameter()
)

with pytest.raises(ValueError):
my_updatable.set_sampler_parameter("the_meaning_of_life", parameters.create())
my_updatable.set_sampler_parameter(
"the_meaning_of_life", parameters.register_new_updatable_parameter()
)


def test_set_internal_parameter():
my_updatable = MinimalUpdatable()
my_updatable.set_internal_parameter("the_meaning_of_life", parameters.create(42.0))
my_updatable.set_internal_parameter(
"the_meaning_of_life", parameters.register_new_updatable_parameter(42.0)
)

assert hasattr(my_updatable, "the_meaning_of_life")
assert my_updatable.the_meaning_of_life == 42.0
Expand All @@ -166,27 +174,33 @@ def test_set_internal_parameter():
def test_set_internal_parameter_rejects_sampler_parameter():
my_updatable = MinimalUpdatable()
with pytest.raises(TypeError):
my_updatable.set_internal_parameter("sampler_param", parameters.create())
my_updatable.set_internal_parameter(
"sampler_param", parameters.register_new_updatable_parameter()
)


def test_set_internal_parameter_rejects_duplicates():
my_updatable = MinimalUpdatable()
my_updatable.set_internal_parameter("the_meaning_of_life", parameters.create(42.0))
my_updatable.set_internal_parameter(
"the_meaning_of_life", parameters.register_new_updatable_parameter(42.0)
)

with pytest.raises(ValueError):
my_updatable.set_internal_parameter(
"the_meaning_of_life", parameters.create(42.0)
"the_meaning_of_life", parameters.register_new_updatable_parameter(42.0)
)

with pytest.raises(ValueError):
my_updatable.set_internal_parameter(
"the_meaning_of_life", parameters.create(41.0)
"the_meaning_of_life", parameters.register_new_updatable_parameter(41.0)
)


def test_update_rejects_internal_parameters():
my_updatable = MinimalUpdatable()
my_updatable.set_internal_parameter("the_meaning_of_life", parameters.create(42.0))
my_updatable.set_internal_parameter(
"the_meaning_of_life", parameters.register_new_updatable_parameter(42.0)
)
assert hasattr(my_updatable, "the_meaning_of_life")

params = ParamsMap({"a": 1.1, "the_meaning_of_life": 34.0})
Expand Down

0 comments on commit a8d9eab

Please sign in to comment.