From f08da40b9502e720dd1b5a5596bfd3d75ebe6af7 Mon Sep 17 00:00:00 2001 From: Lukas Weidenholzer Date: Mon, 14 Aug 2023 18:39:08 +0200 Subject: [PATCH] support axis keyword in array_contains --- openeo_processes_dask/process_implementations/arrays.py | 9 +++------ tests/test_arrays.py | 6 +++++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/openeo_processes_dask/process_implementations/arrays.py b/openeo_processes_dask/process_implementations/arrays.py index 58734514..67fd536e 100644 --- a/openeo_processes_dask/process_implementations/arrays.py +++ b/openeo_processes_dask/process_implementations/arrays.py @@ -126,7 +126,7 @@ def array_concat(array1: ArrayLike, array2: ArrayLike) -> ArrayLike: return concat -def array_contains(data: ArrayLike, value: Any) -> bool: +def array_contains(data: ArrayLike, value: Any, axis=None) -> bool: # TODO: Contrary to the process spec, our implementation does interpret temporal strings before checking them here # This is somewhat implicit in how we currently parse parameters, so cannot be easily changed. @@ -137,13 +137,10 @@ def array_contains(data: ArrayLike, value: Any) -> bool: value_is_valid = True if not value_is_valid: return False - - if len(np.shape(data)) != 1: - return False if pd.isnull(value): - return np.isnan(data).any() + return np.isnan(data).any(axis=axis) else: - return np.isin(data, value).any() + return np.isin(data, value).any(axis=axis) def array_find( diff --git a/tests/test_arrays.py b/tests/test_arrays.py index 439572d1..586eca95 100644 --- a/tests/test_arrays.py +++ b/tests/test_arrays.py @@ -145,7 +145,6 @@ def test_array_concat(array1, array2, expected): ([1, 2, 3], "2", False), ([1, 2, np.nan], np.nan, True), ([[2, 1], [3, 4]], [1, 2], False), - ([[2, 1], [3, 4]], 2, False), ([1, 2, 3], np.int64(2), True), ([1.1, 2.2, 3.3], np.float64(2.2), True), ([True, False, False], np.bool_(True), True), @@ -159,6 +158,11 @@ def test_array_contains(data, value, expected): assert dask_result == expected or dask_result.compute() == expected +def test_array_contains_axis(): + data = np.array([[4, 5, 6], [5, 7, 9]]) + result = array_contains(data, 5, axis=1) + + def test_array_contains_object_dtype(): assert not array_contains([{"a": "b"}, {"c": "d"}], {"a": "b"}) assert not array_contains(np.array([{"a": "b"}, {"c": "d"}]), {"a": "b"})