Skip to content

Commit

Permalink
support axis keyword in array_contains
Browse files Browse the repository at this point in the history
  • Loading branch information
Lukas Weidenholzer committed Aug 14, 2023
1 parent b1a1196 commit f08da40
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
9 changes: 3 additions & 6 deletions openeo_processes_dask/process_implementations/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def array_concat(array1: ArrayLike, array2: ArrayLike) -> ArrayLike:
return concat


def array_contains(data: ArrayLike, value: Any) -> bool:
def array_contains(data: ArrayLike, value: Any, axis=None) -> bool:
# TODO: Contrary to the process spec, our implementation does interpret temporal strings before checking them here
# This is somewhat implicit in how we currently parse parameters, so cannot be easily changed.

Expand All @@ -137,13 +137,10 @@ def array_contains(data: ArrayLike, value: Any) -> bool:
value_is_valid = True
if not value_is_valid:
return False

if len(np.shape(data)) != 1:
return False
if pd.isnull(value):
return np.isnan(data).any()
return np.isnan(data).any(axis=axis)
else:
return np.isin(data, value).any()
return np.isin(data, value).any(axis=axis)


def array_find(
Expand Down
6 changes: 5 additions & 1 deletion tests/test_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def test_array_concat(array1, array2, expected):
([1, 2, 3], "2", False),
([1, 2, np.nan], np.nan, True),
([[2, 1], [3, 4]], [1, 2], False),
([[2, 1], [3, 4]], 2, False),
([1, 2, 3], np.int64(2), True),
([1.1, 2.2, 3.3], np.float64(2.2), True),
([True, False, False], np.bool_(True), True),
Expand All @@ -159,6 +158,11 @@ def test_array_contains(data, value, expected):
assert dask_result == expected or dask_result.compute() == expected


def test_array_contains_axis():
data = np.array([[4, 5, 6], [5, 7, 9]])
result = array_contains(data, 5, axis=1)


def test_array_contains_object_dtype():
assert not array_contains([{"a": "b"}, {"c": "d"}], {"a": "b"})
assert not array_contains(np.array([{"a": "b"}, {"c": "d"}]), {"a": "b"})
Expand Down

0 comments on commit f08da40

Please sign in to comment.