Skip to content

Commit

Permalink
Testing new measurement code.
Browse files Browse the repository at this point in the history
  • Loading branch information
vitenti committed Jul 6, 2024
1 parent 6fded03 commit d6d5056
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 0 deletions.
123 changes: 123 additions & 0 deletions tests/metadata/test_metadata_two_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
type_to_sacc_string_harmonic as harmonic,
type_to_sacc_string_real as real,
Window,
TwoPointMeasurement,
)
from firecrown.likelihood.source import SourceGalaxy
from firecrown.likelihood.two_point import TwoPoint
Expand Down Expand Up @@ -681,6 +682,13 @@ def test_inferred_galaxy_zdist_serialization(harmonic_bin_1: InferredGalaxyZDist
assert harmonic_bin_1 == recovered


def test_two_point_xy_str(
harmonic_bin_1: InferredGalaxyZDist, harmonic_bin_2: InferredGalaxyZDist
):
xy = TwoPointXY(x=harmonic_bin_1, y=harmonic_bin_2)
assert str(xy) == f"({harmonic_bin_1.bin_name}, {harmonic_bin_2.bin_name})"


def test_two_point_xy_serialization(
harmonic_bin_1: InferredGalaxyZDist, harmonic_bin_2: InferredGalaxyZDist
):
Expand All @@ -690,6 +698,16 @@ def test_two_point_xy_serialization(
# is.
recovered = TwoPointXY.from_yaml(s)
assert xy == recovered
assert str(xy) == str(recovered)


def test_two_point_cells_str(
harmonic_bin_1: InferredGalaxyZDist, harmonic_bin_2: InferredGalaxyZDist
):
ells = np.array(np.linspace(0, 100, 100), dtype=np.int64)
xy = TwoPointXY(x=harmonic_bin_1, y=harmonic_bin_2)
cells = TwoPointCells(ells=ells, XY=xy)
assert str(cells) == f"{str(xy)}[{cells.get_sacc_name()}]"


def test_two_point_cells_serialization(
Expand All @@ -701,6 +719,8 @@ def test_two_point_cells_serialization(
s = cells.to_yaml()
recovered = TwoPointCells.from_yaml(s)
assert cells == recovered
assert str(xy) == str(recovered.XY)
assert str(cells) == str(recovered)


def test_window_serialization(window_1: Window):
Expand All @@ -724,6 +744,8 @@ def test_two_point_xi_theta_serialization(
s = xi_theta.to_yaml()
recovered = TwoPointXiTheta.from_yaml(s)
assert xi_theta == recovered
assert str(xy) == str(recovered.XY)
assert str(xi_theta) == str(recovered)


def test_two_point_from_metadata_cells(
Expand Down Expand Up @@ -816,3 +838,104 @@ def test_two_point_from_metadata_cells_unsupported_type(wl_factory, nc_factory):
match="Measurement .* not supported!",
):
TwoPoint.from_metadata_harmonic([cells], wl_factory, nc_factory)


def test_two_point_measurement():

data = np.array([1, 2, 3, 4, 5])
indices = np.array([1, 2, 3, 4, 5])
covariance_name = "cov"
measure = TwoPointMeasurement(
data=data,
indices=indices,
covariance_name=covariance_name,
)
assert_array_equal(measure.data, data)
assert_array_equal(measure.indices, indices)
assert measure.covariance_name == covariance_name


def test_two_point_measurement_invalid_data():
data = np.array([1, 2, 3, 4, 5]).reshape(-1, 1)
indices = np.array([1, 2, 3, 4, 5])
covariance_name = "cov"
with pytest.raises(
ValueError,
match="Data should be a 1D array.",
):
TwoPointMeasurement(
data=data,
indices=indices,
covariance_name=covariance_name,
)


def test_two_point_measurement_invalid_indices():
data = np.array([1, 2, 3, 4, 5])
indices = np.array([1, 2, 3, 4, 5]).reshape(-1, 1)
covariance_name = "cov"
with pytest.raises(
ValueError,
match="Data and indices should have the same shape.",
):
TwoPointMeasurement(
data=data,
indices=indices,
covariance_name=covariance_name,
)


def test_two_point_measurement_eq():
data = np.array([1, 2, 3, 4, 5])
indices = np.array([1, 2, 3, 4, 5])
covariance_name = "cov"
measure_1 = TwoPointMeasurement(
data=data,
indices=indices,
covariance_name=covariance_name,
)
measure_2 = TwoPointMeasurement(
data=data,
indices=indices,
covariance_name=covariance_name,
)
assert measure_1 == measure_2


def test_two_point_measurement_neq():
data = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
indices = np.array([1, 2, 3, 4, 5])
covariance_name = "cov"
measure_1 = TwoPointMeasurement(
data=data,
indices=indices,
covariance_name=covariance_name,
)
measure_2 = TwoPointMeasurement(
data=data,
indices=indices,
covariance_name="cov2",
)
assert measure_1 != measure_2
measure_3 = TwoPointMeasurement(
data=data,
indices=indices,
covariance_name=covariance_name,
)
measure_4 = TwoPointMeasurement(
data=data,
indices=indices + 1,
covariance_name=covariance_name,
)
assert measure_3 != measure_4
measure_5 = TwoPointMeasurement(
data=data,
indices=indices,
covariance_name=covariance_name,
)
measure_6 = TwoPointMeasurement(
data=data + 1.0,
indices=indices,
covariance_name=covariance_name,
)
assert measure_5 != measure_6
19 changes: 19 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
upper_triangle_indices,
save_to_sacc,
compare_optional_arrays,
compare_optionals,
base_model_from_yaml,
)

Expand Down Expand Up @@ -74,3 +75,21 @@ def test_compare_optional_arrays_():
def test_base_model_from_yaml_wrong():
with pytest.raises(ValueError):
_ = base_model_from_yaml(str, "wrong")


def test_compare_optionals():
x = "test"
y = "test"
assert compare_optionals(x, y)

z = "test2"
assert not compare_optionals(x, z)

a = None
b = None
assert compare_optionals(a, b)

q = np.array([1, 2, 3])
assert not compare_optionals(q, a)

assert not compare_optionals(a, q)

0 comments on commit d6d5056

Please sign in to comment.