Skip to content

Commit

Permalink
373 derivedparameter base class is not useful (#374)
Browse files Browse the repository at this point in the history
* Remove DerivedParameter base class
* Rename DerivedParameterScalar -> DerivedParameter
* Bad comparison should raise a TypeError
* Improve test coverage
  • Loading branch information
marcpaterno authored Feb 15, 2024
1 parent 481a6bc commit 848b9c1
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .....modeling_tools import ModelingTools
from .....parameters import (
ParamsMap,
DerivedParameterScalar,
DerivedParameter,
DerivedParameterCollection,
)
from .....updatable import UpdatableCollection
Expand Down Expand Up @@ -307,7 +307,7 @@ def _update_source(self, params: ParamsMap):
def _get_derived_parameters(self) -> DerivedParameterCollection:
if self.derived_scale:
assert self.current_tracer_args is not None
derived_scale = DerivedParameterScalar(
derived_scale = DerivedParameter(
"TwoPoint",
f"NumberCountsScale_{self.sacc_tracer}",
self.current_tracer_args.scale,
Expand Down
58 changes: 22 additions & 36 deletions firecrown/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from __future__ import annotations
from typing import Iterable, List, Dict, Set, Tuple, Optional, Iterator, Sequence
import warnings
from abc import ABC, abstractmethod


def parameter_get_full_name(prefix: Optional[str], param: str) -> str:
Expand Down Expand Up @@ -108,7 +107,10 @@ def __eq__(self, other: object):
are equal (including appearing in the same order).
"""
if not isinstance(other, RequiredParameters):
return NotImplemented
n = type(other).__name__
raise TypeError(
f"Cannot compare a RequiredParameter to an object of type {n}"
)
return self.params_names == other.params_names

def get_params_names(self) -> Iterator[str]:
Expand All @@ -119,67 +121,51 @@ def get_params_names(self) -> Iterator[str]:
yield name


class DerivedParameter(ABC):
"""Represents a derived parameter generated by an Updatable object
This class provide the type that encapsulate a derived quantity computed
by an Updatable object during a statistical analysis.
"""

def __init__(self, section: str, name: str):
"""Constructs a new derived parameter."""
self.section: str = section
self.name: str = name

def get_full_name(self):
"""Constructs the full name using section--name."""
return f"{self.section}--{self.name}"

@abstractmethod
def get_val(self):
"""Returns the value contained."""


class DerivedParameterScalar(DerivedParameter):
class DerivedParameter:
"""Represents a derived scalar parameter generated by an Updatable object
This class provide the type that encapsulate a derived scalar quantity (represented
by a float) computed by an Updatable object during a statistical analysis.
"""

def __init__(self, section: str, name: str, val: float):
super().__init__(section, name)

"""Initialize the DerivedParameter object."""
self.section: str = section
self.name: str = name
if not isinstance(val, float):
raise TypeError(
"DerivedParameterScalar expects a float but received a "
+ str(type(val))
"DerivedParameter expects a float but received a " + str(type(val))
)
self.val: float = val

def get_val(self) -> float:
"""Return the value of this parameter."""
return self.val

def __eq__(self, other: object) -> bool:
"""Compare two DerivedParameterScalar objects for equality.
"""Compare two DerivedParameter objects for equality.
This implementation raises a NotImplemented exception unless both
objects are DerivedParameterScalar objects.
objects are DerivedParameter objects.
Two DerivedParameterScalar objects are equal if they have the same
Two DerivedParameter objects are equal if they have the same
section, name and value.
"""
if not isinstance(other, DerivedParameterScalar):
if not isinstance(other, DerivedParameter):
raise NotImplementedError(
"DerivedParameterScalar comparison is only implemented for "
"DerivedParameterScalar objects"
"DerivedParameter comparison is only implemented for "
"DerivedParameter objects"
)
return (
self.section == other.section
and self.name == other.name
and self.val == other.val
)

def get_full_name(self):
"""Constructs the full name using section--name."""
return f"{self.section}--{self.name}"


class DerivedParameterCollection:
"""Represents a list of DerivedParameter objects."""
Expand All @@ -189,8 +175,8 @@ def __init__(self, derived_parameters: Sequence[DerivedParameter]):

if not all(isinstance(x, DerivedParameter) for x in derived_parameters):
raise TypeError(
"DerivedParameterCollection expects a list of DerivedParameter but "
"received a " + str([str(type(x)) for x in derived_parameters])
"DerivedParameterCollection expects a list of DerivedParameter"
"but received a " + str([str(type(x)) for x in derived_parameters])
)

self.derived_parameters: Dict[str, DerivedParameter] = {}
Expand Down
4 changes: 2 additions & 2 deletions tests/likelihood/lkdir/lkmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

import sacc
from firecrown.parameters import DerivedParameterCollection, DerivedParameterScalar
from firecrown.parameters import DerivedParameterCollection, DerivedParameter
from firecrown.likelihood.likelihood import Likelihood, NamedParameters
from firecrown.modeling_tools import ModelingTools
from firecrown import parameters
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(self):

def _get_derived_parameters(self) -> DerivedParameterCollection:
return DerivedParameterCollection(
[DerivedParameterScalar("derived_section", "derived_param0", 1.0)]
[DerivedParameter("derived_section", "derived_param0", 1.0)]
)

def read(self, sacc_data: sacc.Sacc) -> None:
Expand Down
93 changes: 61 additions & 32 deletions tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from firecrown.parameters import RequiredParameters, parameter_get_full_name, ParamsMap
from firecrown.parameters import (
DerivedParameterScalar,
DerivedParameter,
DerivedParameterCollection,
register_new_updatable_parameter,
create,
Expand Down Expand Up @@ -44,7 +44,14 @@ def test_register_new_updatable_parameter_with_float_arg():
InternalParameter ."""
a_parameter = register_new_updatable_parameter(1.5)
assert isinstance(a_parameter, InternalParameter)
assert a_parameter.value == 1.5
assert a_parameter.get_value() == 1.5


def test_setting_internal_parameter():
a_parameter = register_new_updatable_parameter(1.0)
assert a_parameter.value == 1.0
a_parameter.set_value(2.0)
assert a_parameter.value == 2.0


def test_register_new_updatable_parameter_with_wrong_arg():
Expand All @@ -63,6 +70,19 @@ def test_required_parameters_length():
assert len(b) == 2


def test_required_parameters_equality_testing():
a1 = RequiredParameters(["a"])
a2 = RequiredParameters(["a"])
assert a1 == a2
assert a1 is not a2
b = RequiredParameters(["b"])
assert a1 != b
with pytest.raises(
TypeError, match="Cannot compare a RequiredParameter to an object of type int"
):
_ = a1 == 10


def test_get_params_names_does_not_allow_mutation():
"""The caller of RequiredParameters.get_params_names should not be able to modify
the state of the object on which the call was made."""
Expand Down Expand Up @@ -104,7 +124,7 @@ def test_parameter_get_full_name_without_prefix():


def test_derived_parameter_scalar():
derived_param = DerivedParameterScalar("sec1", "name1", 3.14)
derived_param = DerivedParameter("sec1", "name1", 3.14)

assert isinstance(derived_param.get_val(), float)
assert derived_param.get_val() == 3.14
Expand All @@ -115,35 +135,44 @@ def test_derived_parameter_wrong_type():
"""Try instantiating DerivedParameter objects with wrong types."""

with pytest.raises(TypeError):
_ = DerivedParameterScalar( # pylint: disable-msg=E0110,W0612
_ = DerivedParameter( # pylint: disable-msg=E0110,W0612
"sec1", "name1", "not a float" # type: ignore
)
with pytest.raises(TypeError):
_ = DerivedParameterScalar( # pylint: disable-msg=E0110,W0612
_ = DerivedParameter( # pylint: disable-msg=E0110,W0612
"sec1", "name1", [3.14] # type: ignore
)
with pytest.raises(TypeError):
_ = DerivedParameterScalar( # pylint: disable-msg=E0110,W0612
_ = DerivedParameter( # pylint: disable-msg=E0110,W0612
"sec1", "name1", np.array([3.14]) # type: ignore
)


def test_derived_parameters_collection():
olist = [
DerivedParameterScalar("sec1", "name1", 3.14),
DerivedParameterScalar("sec2", "name2", 2.72),
DerivedParameter("sec1", "name1", 3.14),
DerivedParameter("sec2", "name2", 2.72),
]
orig = DerivedParameterCollection(olist)
clist = orig.get_derived_list()
clist.append(DerivedParameterScalar("sec3", "name3", 0.58))
clist.append(DerivedParameter("sec3", "name3", 0.58))
assert orig.get_derived_list() == olist


def test_derived_parameters_collection_rejects_bad_list():
badlist = [1, 3, 5]
with pytest.raises(TypeError):
# We have to tell mypy to ignore the type error on the
# next line, because it is the very type error we are
# testing.
_ = DerivedParameterCollection(badlist) # type: ignore


def test_derived_parameters_collection_add():
olist = [
DerivedParameterScalar("sec1", "name1", 3.14),
DerivedParameterScalar("sec2", "name2", 2.72),
DerivedParameterScalar("sec2", "name3", 0.58),
DerivedParameter("sec1", "name1", 3.14),
DerivedParameter("sec2", "name2", 2.72),
DerivedParameter("sec2", "name3", 0.58),
]
dpc1 = DerivedParameterCollection(olist)
dpc2 = None
Expand All @@ -158,16 +187,16 @@ def test_derived_parameters_collection_add():

def test_derived_parameters_collection_add_iter():
olist1 = [
DerivedParameterScalar("sec1", "name1", 3.14),
DerivedParameterScalar("sec2", "name2", 2.72),
DerivedParameterScalar("sec2", "name3", 0.58),
DerivedParameter("sec1", "name1", 3.14),
DerivedParameter("sec2", "name2", 2.72),
DerivedParameter("sec2", "name3", 0.58),
]
dpc1 = DerivedParameterCollection(olist1)

olist2 = [
DerivedParameterScalar("sec3", "name1", 3.14e1),
DerivedParameterScalar("sec3", "name2", 2.72e1),
DerivedParameterScalar("sec3", "name3", 0.58e1),
DerivedParameter("sec3", "name1", 3.14e1),
DerivedParameter("sec3", "name2", 2.72e1),
DerivedParameter("sec3", "name3", 0.58e1),
]
dpc2 = DerivedParameterCollection(olist2)

Expand All @@ -181,35 +210,35 @@ def test_derived_parameters_collection_add_iter():


def test_derived_parameter_eq():
dv1 = DerivedParameterScalar("sec1", "name1", 3.14)
dv2 = DerivedParameterScalar("sec1", "name1", 3.14)
dv1 = DerivedParameter("sec1", "name1", 3.14)
dv2 = DerivedParameter("sec1", "name1", 3.14)

assert dv1 == dv2


def test_derived_parameter_eq_invalid():
dv1 = DerivedParameterScalar("sec1", "name1", 3.14)
dv1 = DerivedParameter("sec1", "name1", 3.14)

with pytest.raises(
NotImplementedError,
match="DerivedParameterScalar comparison is only "
"implemented for DerivedParameterScalar objects",
match="DerivedParameter comparison is only "
"implemented for DerivedParameter objects",
):
_ = dv1 == 1.0


def test_derived_parameters_collection_eq():
olist1 = [
DerivedParameterScalar("sec1", "name1", 3.14),
DerivedParameterScalar("sec2", "name2", 2.72),
DerivedParameterScalar("sec2", "name3", 0.58),
DerivedParameter("sec1", "name1", 3.14),
DerivedParameter("sec2", "name2", 2.72),
DerivedParameter("sec2", "name3", 0.58),
]
dpc1 = DerivedParameterCollection(olist1)

olist2 = [
DerivedParameterScalar("sec1", "name1", 3.14),
DerivedParameterScalar("sec2", "name2", 2.72),
DerivedParameterScalar("sec2", "name3", 0.58),
DerivedParameter("sec1", "name1", 3.14),
DerivedParameter("sec2", "name2", 2.72),
DerivedParameter("sec2", "name3", 0.58),
]
dpc2 = DerivedParameterCollection(olist2)

Expand All @@ -218,9 +247,9 @@ def test_derived_parameters_collection_eq():

def test_derived_parameters_collection_eq_invalid():
olist1 = [
DerivedParameterScalar("sec1", "name1", 3.14),
DerivedParameterScalar("sec2", "name2", 2.72),
DerivedParameterScalar("sec2", "name3", 0.58),
DerivedParameter("sec1", "name1", 3.14),
DerivedParameter("sec2", "name2", 2.72),
DerivedParameter("sec2", "name3", 0.58),
]
dpc1 = DerivedParameterCollection(olist1)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_updatable.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from firecrown.parameters import (
RequiredParameters,
ParamsMap,
DerivedParameterScalar,
DerivedParameter,
DerivedParameterCollection,
)

Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(self):
self.B = parameters.register_new_updatable_parameter()

def _get_derived_parameters(self) -> DerivedParameterCollection:
derived_scale = DerivedParameterScalar("Section", "Name", self.A + self.B)
derived_scale = DerivedParameter("Section", "Name", self.A + self.B)
derived_parameters = DerivedParameterCollection([derived_scale])

return derived_parameters
Expand Down Expand Up @@ -316,7 +316,7 @@ def test_nesting_updatables_derived_parameters(nested_updatables):

base.update(params)

derived_scale = DerivedParameterScalar("Section", "Name", 9.0)
derived_scale = DerivedParameter("Section", "Name", 9.0)
derived_parameters = DerivedParameterCollection([derived_scale])

assert base.get_derived_parameters() == derived_parameters
Expand Down

0 comments on commit 848b9c1

Please sign in to comment.