Skip to content

Commit

Permalink
Filter labels (#277)
Browse files Browse the repository at this point in the history
* update openeo-processes

* bump 2023.9.1

* bump 2023.9.0

* update filter_labels and add tests

* pre-commit run
  • Loading branch information
ValentinaHutter authored Sep 16, 2024
1 parent 4bd5081 commit 8512783
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 12 deletions.
23 changes: 17 additions & 6 deletions openeo_processes_dask/process_implementations/cubes/_filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
import warnings
from typing import Callable
from typing import Any, Callable, Optional

import dask.array as da
import geopandas as gpd
Expand Down Expand Up @@ -87,16 +87,27 @@ def filter_temporal(
return filtered


def filter_labels(
    data: RasterCube, condition: Callable, dimension: str, context: Optional[Any] = None
) -> RasterCube:
    """Keep only the labels of ``dimension`` for which ``condition`` is truthy.

    ``condition`` is evaluated once per label (through ``np.vectorize``). Each
    call receives the label as its first positional argument together with the
    ``positional_parameters`` and ``named_parameters`` keyword arguments built
    below.

    Args:
        data: Cube to filter.
        condition: Callable evaluated for every label of ``dimension``.
        dimension: Name of the dimension whose labels are filtered.
        context: Optional extra value handed to ``condition`` under the
            ``"context"`` key of ``named_parameters``. Any falsy value is
            replaced by an empty dict.

    Returns:
        ``data`` restricted along ``dimension`` to every label for which the
        condition held, in their original order. If no label matches, the
        result has length zero along ``dimension``.

    Raises:
        DimensionNotAvailable: If ``dimension`` is not in ``data.dims``.
    """
    if dimension not in data.dims:
        raise DimensionNotAvailable(
            f"Provided dimension ({dimension}) not found in data.dims: {data.dims}"
        )

    labels = np.array(data[dimension].values)
    context = context or {}

    # "x" is supplied positionally (argument 0) on every call; the full label
    # array and the context are additionally exposed by name.
    positional_parameters = {"x": 0}
    named_parameters = {"x": labels, "context": context}

    evaluate = np.vectorize(condition)
    label_mask = evaluate(
        labels,
        positional_parameters=positional_parameters,
        named_parameters=named_parameters,
    )

    # flatnonzero yields the 1-D positions of *all* matching labels. Taking the
    # first row of np.argwhere would keep a single label at most and fail with
    # IndexError when nothing matches.
    matching_positions = np.flatnonzero(label_mask)

    # A mapping (rather than **kwargs) keeps this working even if the
    # dimension name collides with one of isel's own keyword arguments.
    return data.isel({dimension: matching_positions})


Expand Down
30 changes: 24 additions & 6 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import datetime
from functools import partial

import numpy as np
Expand All @@ -7,12 +8,7 @@
import xarray as xr
from openeo_pg_parser_networkx.pg_schema import ParameterReference, TemporalInterval

from openeo_processes_dask.process_implementations.cubes._filter import (
filter_bands,
filter_bbox,
filter_spatial,
filter_temporal,
)
from openeo_processes_dask.process_implementations.cubes._filter import *
from openeo_processes_dask.process_implementations.cubes.reduce import reduce_dimension
from openeo_processes_dask.process_implementations.exceptions import (
DimensionNotAvailable,
Expand Down Expand Up @@ -68,6 +64,28 @@ def test_filter_temporal(temporal_interval, bounding_box, random_raster_data):
filter_temporal(invalid_input_cube, temporal_interval)


@pytest.mark.parametrize("size", [(30, 30, 30, 3)])
@pytest.mark.parametrize("dtype", [np.uint8])
def test_filter_labels(
    temporal_interval, bounding_box, random_raster_data, process_registry
):
    """Filtering ``bands`` with an ``eq`` condition on "B04" leaves one band."""
    cube = create_fake_rastercube(
        data=random_raster_data,
        spatial_extent=bounding_box,
        temporal_extent=temporal_interval,
        bands=["B02", "B03", "B04"],
        backend="dask",
    )

    # Condition: the label bound to parameter "x" must equal "B04".
    eq_implementation = process_registry["eq"].implementation
    is_b04 = partial(
        eq_implementation,
        y="B04",
        x=ParameterReference(from_parameter="x"),
    )

    filtered = filter_labels(data=cube, condition=is_b04, dimension="bands")
    remaining_bands = filtered["bands"]
    assert len(remaining_bands) == 1


@pytest.mark.parametrize("size", [(1, 1, 1, 2)])
@pytest.mark.parametrize("dtype", [np.uint8])
def test_filter_bands(temporal_interval, bounding_box, random_raster_data):
Expand Down

0 comments on commit 8512783

Please sign in to comment.