From 712cf1061ae205587cdb1d3aa38ca410788706ba Mon Sep 17 00:00:00 2001
From: ValentinaHutter <85164505+ValentinaHutter@users.noreply.github.com>
Date: Mon, 16 Sep 2024 16:02:17 +0200
Subject: [PATCH] add trim_cube process (#278)

* add trim_cube process

* fix and add more tests

* run pre-commit
---
 .../process_implementations/cubes/general.py  | 35 +++++++++++++++++++
 tests/test_dimensions.py                      | 22 ++++++++++++
 2 files changed, 57 insertions(+)

diff --git a/openeo_processes_dask/process_implementations/cubes/general.py b/openeo_processes_dask/process_implementations/cubes/general.py
index 12c9f73b..d55cd86a 100644
--- a/openeo_processes_dask/process_implementations/cubes/general.py
+++ b/openeo_processes_dask/process_implementations/cubes/general.py
@@ -38,6 +38,41 @@ def create_data_cube() -> RasterCube:
     return xr.DataArray()
 
 
+def trim_cube(data: RasterCube) -> RasterCube:
+    """Drop every label whose slice holds nothing but NaN values.
+
+    Only the temporal, band and "other" dimensions reported by the ``openeo``
+    accessor are trimmed; any remaining dimension keeps all of its labels.
+
+    Args:
+        data: The data cube to trim.
+
+    Returns:
+        The cube restricted to the labels that have at least one non-NaN value.
+
+    Raises:
+        ValueError: If a trimmed dimension has no label with non-NaN data.
+    """
+    for dim in data.dims:
+        if not (
+            dim in data.openeo.temporal_dims
+            or dim in data.openeo.band_dims
+            or dim in data.openeo.other_dims
+        ):
+            continue
+
+        other_dims = [d for d in data.dims if d != dim]
+        # True for each label along `dim` with at least one non-NaN value.
+        has_data = np.asarray(~np.isnan(data).all(dim=other_dims))
+        if not has_data.any():
+            raise ValueError("Data contains NaN values only.")
+        # Select by position rather than by label so that duplicated or NaN
+        # coordinate labels cannot break or widen the selection.
+        data = data.isel({dim: np.flatnonzero(has_data)})
+
+    return data
+
+
 def dimension_labels(data: RasterCube, dimension: str) -> ArrayLike:
     if dimension not in data.dims:
         raise DimensionNotAvailable(
diff --git a/tests/test_dimensions.py b/tests/test_dimensions.py
index 2b22e28c..37ae373d 100644
--- a/tests/test_dimensions.py
+++ b/tests/test_dimensions.py
@@ -6,6 +6,7 @@
     drop_dimension,
     rename_dimension,
     rename_labels,
+    trim_cube,
 )
 from openeo_processes_dask.process_implementations.exceptions import (
     DimensionLabelCountMismatch,
@@ -124,3 +125,24 @@ def test_rename_labels(temporal_interval, bounding_box, random_raster_data):
         dimension="bands",
         target=["B02", "B03", "B04", "B05", "B08", "B11", "B12"],
     )
+
+
+@pytest.mark.parametrize("size", [(30, 30, 20, 4)])
+@pytest.mark.parametrize("dtype", [np.float32])
+def test_trim_cube(temporal_interval, bounding_box, random_raster_data):
+    """An all-NaN slice is dropped; a cube that is NaN everywhere is rejected."""
+    input_cube = create_fake_rastercube(
+        data=random_raster_data,
+        spatial_extent=bounding_box,
+        temporal_extent=temporal_interval,
+        bands=["B02", "B03", "B04", "B08"],
+        backend="dask",
+    )
+    # Blank out index 2 along the last axis so that one slice is entirely NaN.
+    input_cube[:, :, :, 2] = np.full((30, 30, 20), np.nan)
+    output_cube = trim_cube(input_cube)
+    assert output_cube.shape == (30, 30, 20, 3)
+
+    all_nan = input_cube * np.nan
+    with pytest.raises(ValueError):
+        trim_cube(all_nan)