Skip to content

Commit

Permalink
add trim_cube process (#278)
Browse files Browse the repository at this point in the history
* add trim_cube process

* fix and add more tests

* run pre-commit
  • Loading branch information
ValentinaHutter authored Sep 16, 2024
1 parent 1091de5 commit 712cf10
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
17 changes: 17 additions & 0 deletions openeo_processes_dask/process_implementations/cubes/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,23 @@ def create_data_cube() -> RasterCube:
return xr.DataArray()


def trim_cube(data) -> RasterCube:
for dim in data.dims:
if (
dim in data.openeo.temporal_dims
or dim in data.openeo.band_dims
or dim in data.openeo.other_dims
):
values = data[dim].values
other_dims = [d for d in data.dims if d != dim]
available_data = values[(np.isnan(data)).all(dim=other_dims) == 0]
if len(available_data) == 0:
raise ValueError(f"Data contains NaN values only. ")
data = data.sel({dim: available_data})

return data


def dimension_labels(data: RasterCube, dimension: str) -> ArrayLike:
if dimension not in data.dims:
raise DimensionNotAvailable(
Expand Down
20 changes: 20 additions & 0 deletions tests/test_dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
drop_dimension,
rename_dimension,
rename_labels,
trim_cube,
)
from openeo_processes_dask.process_implementations.exceptions import (
DimensionLabelCountMismatch,
Expand Down Expand Up @@ -124,3 +125,22 @@ def test_rename_labels(temporal_interval, bounding_box, random_raster_data):
dimension="bands",
target=["B02", "B03", "B04", "B05", "B08", "B11", "B12"],
)


@pytest.mark.parametrize("size", [(30, 30, 20, 4)])
@pytest.mark.parametrize("dtype", [np.float32])
def test_trim_cube(temporal_interval, bounding_box, random_raster_data):
input_cube = create_fake_rastercube(
data=random_raster_data,
spatial_extent=bounding_box,
temporal_extent=temporal_interval,
bands=["B02", "B03", "B04", "B08"],
backend="dask",
)
input_cube[:, :, :, 2] = np.zeros((30, 30, 20)) * np.nan
output_cube = trim_cube(input_cube)
assert output_cube.shape == (30, 30, 20, 3)

all_nan = input_cube * np.nan
with pytest.raises(ValueError):
output_cube = trim_cube(all_nan)

0 comments on commit 712cf10

Please sign in to comment.