Skip to content

Commit

Permalink
fix: reduce_spatial (#164)
Browse files Browse the repository at this point in the history
* fix reduce_spatial

* Fix context

* Revert "fix reduce_spatial"

This reverts commit f1fdac3.

* Revert "Revert "fix reduce_spatial""

This reverts commit b3d45dd.

* Revert "Fix context"

This reverts commit c8fbcfc.
  • Loading branch information
clausmichele authored Oct 11, 2023
1 parent c254f6e commit 6e77c76
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,8 @@ def reduce_spatial(
spatial_dims = data.openeo.spatial_dims if data.openeo.spatial_dims else None
return data.reduce(
reducer,
dimension=spatial_dims,
dim=spatial_dims,
keep_attrs=True,
context=context,
positional_parameters=positional_parameters,
named_parameters=named_parameters,
)
36 changes: 35 additions & 1 deletion tests/test_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import xarray as xr
from openeo_pg_parser_networkx.pg_schema import ParameterReference

from openeo_processes_dask.process_implementations.cubes.reduce import reduce_dimension
from openeo_processes_dask.process_implementations.cubes.reduce import (
reduce_dimension,
reduce_spatial,
)
from tests.general_checks import general_output_checks
from tests.mockdata import create_fake_rastercube

Expand Down Expand Up @@ -39,3 +42,34 @@ def test_reduce_dimension(
)

xr.testing.assert_equal(output_cube, input_cube.mean(dim="t"))


@pytest.mark.parametrize("size", [(30, 30, 20, 4)])
@pytest.mark.parametrize("dtype", [np.float32])
def test_reduce_spatial(
temporal_interval, bounding_box, random_raster_data, process_registry
):
input_cube = create_fake_rastercube(
data=random_raster_data,
spatial_extent=bounding_box,
temporal_extent=temporal_interval,
bands=["B02", "B03", "B04", "B08"],
backend="dask",
)

_process = partial(
process_registry["sum"].implementation,
ignore_nodata=True,
data=ParameterReference(from_parameter="data"),
)

output_cube = reduce_spatial(data=input_cube, reducer=_process)

general_output_checks(
input_cube=input_cube,
output_cube=output_cube,
verify_attrs=False,
verify_crs=True,
)

xr.testing.assert_equal(output_cube, input_cube.sum(dim=["x", "y"]))

0 comments on commit 6e77c76

Please sign in to comment.