-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Setup test fixtures and add example test for `aggregate_temporal_peri…
…od` (#3) * add test for aggregate temporal * remove poetry-lock hook * Update test suite with more flexible fixtures * Add extensive test for aggregate_temporal_period * fix typing in aggregate_temporal_period * remove obsolete alias * add test for apply * Add test for reduce_dimension
- Loading branch information
1 parent
166d49d
commit 9f7a958
Showing
9 changed files
with
236 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,32 @@ | ||
import logging | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from tests.mockdata import generate_fake_rastercube | ||
from openeo_pg_parser_networkx.pg_schema import BoundingBox, TemporalInterval | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@pytest.fixture | ||
def mock_rastercube_factory(): | ||
return generate_fake_rastercube | ||
def random_data(size, dtype, seed=42): | ||
rng = np.random.default_rng(seed) | ||
data = rng.integers(-100, 100, size=size) | ||
data = data.astype(dtype) | ||
return data | ||
|
||
|
||
@pytest.fixture | ||
def bounding_box(west=10.45, east=10.5, south=46.1, north=46.2, crs="EPSG:4326"): | ||
spatial_extent = { | ||
"west": west, | ||
"east": east, | ||
"south": south, | ||
"north": north, | ||
"crs": crs, | ||
} | ||
return BoundingBox.parse_obj(spatial_extent) | ||
|
||
|
||
@pytest.fixture | ||
def temporal_interval(interval=["2018-05-01", "2018-06-01"]): | ||
return TemporalInterval.parse_obj(interval) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Checks here are inspired by makepath/xarray-spatial/tests/general_checks.py | ||
|
||
import dask.array as da | ||
import numpy as np | ||
|
||
from openeo_processes_dask.process_implementations.data_model import RasterCube | ||
|
||
|
||
def general_output_checks( | ||
input_cube: RasterCube, | ||
output_cube: RasterCube, | ||
expected_results=None, | ||
verify_crs: bool = False, | ||
verify_attrs: bool = False, | ||
rtol=1e-06, | ||
): | ||
assert isinstance(output_cube.data, type(input_cube.data)) | ||
|
||
if verify_crs: | ||
assert input_cube.rio.crs == output_cube.rio.crs | ||
|
||
if verify_attrs: | ||
assert input_cube.attrs == output_cube.attrs | ||
|
||
if expected_results is not None: | ||
if isinstance(output_cube.data, np.ndarray): | ||
output_data = output_cube.data | ||
elif isinstance(output_cube.data, da.Array): | ||
output_data = output_cube.data.compute() | ||
else: | ||
raise TypeError(f"Unsupported array type: {type(output_cube.data)}") | ||
|
||
np.testing.assert_allclose( | ||
output_data, expected_results, equal_nan=True, rtol=rtol | ||
) | ||
|
||
|
||
def assert_numpy_equals_dask_numpy(numpy_cube, dask_cube, func): | ||
numpy_result = func(numpy_cube) | ||
dask_result = func(dask_cube) | ||
general_output_checks(dask_cube, dask_result) | ||
np.testing.assert_allclose( | ||
numpy_result.data, dask_result.data.compute(), equal_nan=True | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from functools import partial | ||
|
||
import numpy as np | ||
import pytest | ||
from openeo_pg_parser_networkx.pg_schema import TemporalInterval | ||
|
||
from openeo_processes_dask.process_implementations.cubes.aggregate import ( | ||
aggregate_temporal_period, | ||
) | ||
from openeo_processes_dask.process_implementations.math import mean | ||
from tests.general_checks import assert_numpy_equals_dask_numpy, general_output_checks | ||
from tests.mockdata import create_fake_rastercube | ||
|
||
|
||
@pytest.mark.parametrize("size", [(6, 5, 4, 4)]) | ||
@pytest.mark.parametrize("dtype", [np.float64]) | ||
@pytest.mark.parametrize( | ||
"temporal_extent,period,expected", | ||
[ | ||
(["2018-05-01", "2018-05-02"], "hour", 25), | ||
(["2018-05-01", "2018-06-01"], "day", 32), | ||
(["2018-05-01", "2018-06-01"], "week", 5), | ||
(["2018-05-01", "2018-06-01"], "month", 2), | ||
(["2018-01-01", "2018-12-31"], "season", 5), | ||
(["2018-01-01", "2018-12-31"], "year", 1), | ||
], | ||
) | ||
def test_aggregate_temporal_period( | ||
temporal_extent, period, expected, bounding_box, random_data | ||
): | ||
"""""" | ||
input_cube = create_fake_rastercube( | ||
data=random_data, | ||
spatial_extent=bounding_box, | ||
temporal_extent=TemporalInterval.parse_obj(temporal_extent), | ||
bands=["B02", "B03", "B04", "B08"], | ||
) | ||
output_cube = aggregate_temporal_period( | ||
data=input_cube, period=period, reducer=mean | ||
) | ||
|
||
general_output_checks( | ||
input_cube=input_cube, | ||
output_cube=output_cube, | ||
verify_attrs=True, | ||
verify_crs=True, | ||
) | ||
|
||
assert len(output_cube.t) == expected | ||
assert isinstance(output_cube.t.values[0], np.datetime64) | ||
|
||
|
||
@pytest.mark.parametrize("size", [(6, 5, 4, 4)]) | ||
@pytest.mark.parametrize("dtype", [np.int32, np.int64, np.float32, np.float64]) | ||
def test_aggregate_temporal_period_numpy_equals_dask( | ||
random_data, bounding_box, temporal_interval | ||
): | ||
numpy_cube = create_fake_rastercube( | ||
data=random_data, | ||
spatial_extent=bounding_box, | ||
temporal_extent=temporal_interval, | ||
bands=["B02", "B03", "B04", "B08"], | ||
backend="numpy", | ||
) | ||
dask_cube = create_fake_rastercube( | ||
data=random_data, | ||
spatial_extent=bounding_box, | ||
temporal_extent=temporal_interval, | ||
bands=["B02", "B03", "B04", "B08"], | ||
backend="dask", | ||
) | ||
|
||
func = partial(aggregate_temporal_period, reducer=mean, period="hour") | ||
assert_numpy_equals_dask_numpy( | ||
numpy_cube=numpy_cube, dask_cube=dask_cube, func=func | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from functools import partial | ||
|
||
import numpy as np | ||
import pytest | ||
from openeo_pg_parser_networkx.pg_schema import ParameterReference | ||
|
||
from openeo_processes_dask.core import process_registry | ||
from openeo_processes_dask.process_implementations.cubes.apply import apply | ||
from openeo_processes_dask.process_implementations.math import add | ||
from tests.general_checks import assert_numpy_equals_dask_numpy, general_output_checks | ||
from tests.mockdata import create_fake_rastercube | ||
|
||
|
||
@pytest.mark.parametrize("size", [(6, 5, 4, 4)]) | ||
@pytest.mark.parametrize("dtype", [np.float32]) | ||
def test_apply(temporal_interval, bounding_box, random_data): | ||
input_cube = create_fake_rastercube( | ||
data=random_data, | ||
spatial_extent=bounding_box, | ||
temporal_extent=temporal_interval, | ||
bands=["B02", "B03", "B04", "B08"], | ||
) | ||
|
||
_process = partial( | ||
process_registry["add"], y=1, x=ParameterReference(from_parameter="x") | ||
) | ||
|
||
output_cube = apply(data=input_cube, process=_process) | ||
|
||
general_output_checks( | ||
input_cube=input_cube, | ||
output_cube=output_cube, | ||
verify_attrs=True, | ||
verify_crs=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from functools import partial | ||
|
||
import numpy as np | ||
import pytest | ||
from openeo_pg_parser_networkx.pg_schema import ParameterReference | ||
|
||
from openeo_processes_dask.core import process_registry | ||
from openeo_processes_dask.process_implementations.cubes.reduce import reduce_dimension | ||
from tests.general_checks import general_output_checks | ||
from tests.mockdata import create_fake_rastercube | ||
|
||
|
||
@pytest.mark.parametrize("size", [(6, 5, 4, 4)]) | ||
@pytest.mark.parametrize("dtype", [np.float32]) | ||
def test_reduce_dimension(temporal_interval, bounding_box, random_data): | ||
input_cube = create_fake_rastercube( | ||
data=random_data, | ||
spatial_extent=bounding_box, | ||
temporal_extent=temporal_interval, | ||
bands=["B02", "B03", "B04", "B08"], | ||
) | ||
|
||
_process = partial( | ||
process_registry["mean"], y=1, data=ParameterReference(from_parameter="data") | ||
) | ||
|
||
output_cube = reduce_dimension(data=input_cube, reducer=_process, dimension="t") | ||
|
||
general_output_checks( | ||
input_cube=input_cube, | ||
output_cube=output_cube, | ||
verify_attrs=False, | ||
verify_crs=True, | ||
) |