Skip to content

Commit

Permalink
add test of randomized value matching
Browse files Browse the repository at this point in the history
  • Loading branch information
danielfromearth committed Oct 20, 2023
1 parent 8a523e3 commit fa7ee07
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 15 deletions.
6 changes: 3 additions & 3 deletions ncompare/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ def compare_multiple_random_values(
):
"""Iterate through N random samples, and evaluate whether the differences exceed a threshold."""
# Open a variable from each NetCDF
nc_var_a = xr.open_dataset(nc_a, backend_kwargs={"group": groupname})[varname]
nc_var_b = xr.open_dataset(nc_b, backend_kwargs={"group": groupname})[varname]
nc_var_a = xr.open_dataset(nc_a, backend_kwargs={"group": groupname}).variables[varname]
nc_var_b = xr.open_dataset(nc_b, backend_kwargs={"group": groupname}).variables[varname]

num_mismatches = 0
for _ in range(num_comparisons):
Expand Down Expand Up @@ -448,7 +448,7 @@ def _var_properties(group: Union[netCDF4.Dataset, netCDF4.Group], varname: str)


def _match_random_value(
out: Outputter, nc_var_a: xr.DataArray, nc_var_b: xr.DataArray, thresh: float = 1e-6
out: Outputter, nc_var_a: xr.Variable, nc_var_b: xr.Variable, thresh: float = 1e-6
) -> Union[bool, None]:
"""Check whether a randomly selected data point matches between two variables.
Expand Down
17 changes: 12 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from xarray import Dataset

from . import data_for_tests_dir
from ncompare.printing import Outputter

@pytest.fixture(scope="session")
def outputter_obj():
return Outputter()

@pytest.fixture(scope="session")
def temp_data_dir(tmpdir_factory) -> Path:
Expand All @@ -14,11 +19,12 @@ def temp_data_dir(tmpdir_factory) -> Path:
@pytest.fixture(scope="session")
def ds_3dims_2vars_4coords(temp_data_dir) -> Path:
ds = Dataset(
dict(
data_vars=dict(
# "normal" (Gaussian) distribution of mean 0 and variance 1
z1=(["y", "x"], np.random.randn(2, 8)),
z2=(["time", "y"], np.random.randn(10, 2)),
),
dict(
coords=dict(
x=("x", np.linspace(0, 1.0, 8)),
time=("time", np.linspace(0, 1.0, 10)),
c=("y", ["a", "b"]),
Expand All @@ -33,12 +39,13 @@ def ds_3dims_2vars_4coords(temp_data_dir) -> Path:
@pytest.fixture(scope="session")
def ds_4dims_3vars_5coords(temp_data_dir):
ds = Dataset(
dict(
z1=(["y", "x"], np.random.randn(2, 8)),
data_vars=dict(
# "normal" (Gaussian) distribution of mean 10 and standard deviation 2.5
z1=(["y", "x"], 10 + 2.5 * np.random.randn(2, 8)),
z2=(["time", "y"], np.random.randn(10, 2)),
z3=(["y", "z"], np.random.randn(2, 9)),
),
dict(
coords=dict(
x=("x", np.linspace(0, 1.0, 8)),
time=("time", np.linspace(0, 1.0, 10)),
c=("y", ["a", "b"]),
Expand Down
12 changes: 11 additions & 1 deletion tests/test_netcdf_compare.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
from ncompare.core import compare
import xarray as xr

from ncompare.core import compare, _match_random_value


def test_dataset_compare_does_not_raise_exception(ds_3dims_2vars_4coords, ds_4dims_3vars_5coords):
compare(ds_3dims_2vars_4coords, ds_4dims_3vars_5coords)

def test_dataset_compare_does_not_raise_exception_2(ds_3dims_2vars_4coords, ds_3dims_3vars_4coords_1group):
compare(ds_3dims_2vars_4coords, ds_3dims_3vars_4coords_1group)

def test_matching_random_values(ds_3dims_2vars_4coords, ds_4dims_3vars_5coords,
ds_3dims_3vars_4coords_1group, outputter_obj):
variable_array_1 = xr.open_dataset(ds_3dims_2vars_4coords).variables['z1']
variable_array_2 = xr.open_dataset(ds_4dims_3vars_5coords).variables['z1']

assert _match_random_value(outputter_obj, variable_array_1, variable_array_1, ) is True
assert _match_random_value(outputter_obj, variable_array_1, variable_array_2, ) is False
6 changes: 0 additions & 6 deletions tests/test_printing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
import pytest

from ncompare.printing import Outputter


@pytest.fixture
def outputter_obj():
return Outputter()

def test_list_of_strings_diff(outputter_obj):
left, right, both = outputter_obj.lists_diff(['hey', 'yo', 'beebop'],
Expand Down

0 comments on commit fa7ee07

Please sign in to comment.