diff --git a/openeo_processes_dask/process_implementations/cubes/_filter.py b/openeo_processes_dask/process_implementations/cubes/_filter.py index 7eb50f4c..60939cb3 100644 --- a/openeo_processes_dask/process_implementations/cubes/_filter.py +++ b/openeo_processes_dask/process_implementations/cubes/_filter.py @@ -1,7 +1,7 @@ import json import logging import warnings -from typing import Callable +from typing import Any, Callable, Optional import dask.array as da import geopandas as gpd @@ -87,16 +87,27 @@ def filter_temporal( return filtered -def filter_labels(data: RasterCube, condition: Callable, dimension: str) -> RasterCube: +def filter_labels( + data: RasterCube, condition: Callable, dimension: str, context: Optional[Any] = None +) -> RasterCube: if dimension not in data.dims: raise DimensionNotAvailable( f"Provided dimension ({dimension}) not found in data.dims: {data.dims}" ) - labels = data[dimension].values - label_mask = condition(x=labels) - label = labels[label_mask] - data = data.sel(**{dimension: label}) + labels = np.array(data[dimension].values) + if not context: + context = {} + positional_parameters = {"x": 0} + named_parameters = {"x": labels, "context": context} + filter_condition = np.vectorize(condition) + filtered_labels = filter_condition( + labels, + positional_parameters=positional_parameters, + named_parameters=named_parameters, + ) + label = np.argwhere(filtered_labels) + data = data.isel(**{dimension: label[0]}) return data diff --git a/tests/test_filter.py b/tests/test_filter.py index cf7f840f..2b4c2f67 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -1,4 +1,5 @@ import copy +import datetime from functools import partial import numpy as np @@ -7,12 +8,7 @@ import xarray as xr from openeo_pg_parser_networkx.pg_schema import ParameterReference, TemporalInterval -from openeo_processes_dask.process_implementations.cubes._filter import ( - filter_bands, - filter_bbox, - filter_spatial, - filter_temporal, -) +from openeo_processes_dask.process_implementations.cubes._filter import * from openeo_processes_dask.process_implementations.cubes.reduce import reduce_dimension from openeo_processes_dask.process_implementations.exceptions import ( DimensionNotAvailable, @@ -68,6 +64,28 @@ def test_filter_temporal(temporal_interval, bounding_box, random_raster_data): filter_temporal(invalid_input_cube, temporal_interval) +@pytest.mark.parametrize("size", [(30, 30, 30, 3)]) +@pytest.mark.parametrize("dtype", [np.uint8]) +def test_filter_labels( + temporal_interval, bounding_box, random_raster_data, process_registry +): + input_cube = create_fake_rastercube( + data=random_raster_data, + spatial_extent=bounding_box, + temporal_extent=temporal_interval, + bands=["B02", "B03", "B04"], + backend="dask", + ) + _process = partial( + process_registry["eq"].implementation, + y="B04", + x=ParameterReference(from_parameter="x"), + ) + + output_cube = filter_labels(data=input_cube, condition=_process, dimension="bands") + assert len(output_cube["bands"]) == 1 + + @pytest.mark.parametrize("size", [(1, 1, 1, 2)]) @pytest.mark.parametrize("dtype", [np.uint8]) def test_filter_bands(temporal_interval, bounding_box, random_raster_data):