diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index db4237dbd069..8e0162660951 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -61,6 +61,17 @@ def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array: return (ar1_flat[:, None] == ar2_flat[None, :]).any(-1) +def _concat_unique(arr1: Array, arr2: Array) -> tuple[Array, Array]: + """Utility to concatenate the unique values from two arrays.""" + arr1, arr2 = ravel(arr1), ravel(arr2) + arr1, num_unique1 = _unique(arr1, axis=0, size=arr1.size, return_true_size=True) + arr2, num_unique2 = _unique(arr2, axis=0, size=arr2.size, return_true_size=True) + arr = zeros(arr1.size + arr2.size, dtype=dtypes.result_type(arr1, arr2)) + arr = lax.dynamic_update_slice(arr, arr1, (0,)) + arr = lax.dynamic_update_slice(arr, arr2, (num_unique1,)) + return arr, num_unique1 + num_unique2 + + def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *, size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set difference of two 1D arrays. @@ -220,7 +231,39 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike, return cast(Array, out) -def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Array: +@partial(jit, static_argnames=['assume_unique', 'size']) +def _setxor1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, *, + assume_unique: bool, size: int, ) -> Array: + # Ensured by caller + assert arr1.ndim == arr2.ndim == 1 + assert arr1.dtype == arr2.dtype + + if assume_unique: + arr = concatenate([arr1, arr2]) + aux = sort(concatenate([arr1, arr2])) + flag = concatenate((bool(aux.size), aux[1:] != aux[:-1], True), axis=None) + else: + arr, num_unique = _concat_unique(arr1, arr2) + mask = arange(arr.size + 1) < num_unique + 1 + _, aux = lax.sort([~mask[1:], arr], is_stable=True, num_keys=2) + flag = mask & concatenate((bool(aux.size), aux[1:] != aux[:-1], False), + axis=None).at[num_unique].set(True) + aux_mask = flag[1:] & flag[:-1] + num_results = aux_mask.sum() + if aux.size: + indices = nonzero(aux_mask, size=size, fill_value=len(aux))[0] + vals = aux.at[indices].get(mode='fill', fill_value=0) + else: + vals = zeros(size, aux.dtype) + if fill_value is None: + vals = where(arange(len(vals)) < num_results, vals, vals.max()) + return where(arange(len(vals)) < num_results, vals, vals.min()) + else: + return where(arange(len(vals)) < num_results, vals, fill_value) + + +def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *, + size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set-wise xor of elements in two arrays. JAX implementation of :func:`numpy.setxor1d`. @@ -234,6 +277,12 @@ def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Arr assume_unique: if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if ``assume_unique`` is True and the input arrays contain duplicates, the behavior is undefined. default: False. + size: if specified, return only the first ``size`` sorted elements. If there are fewer + elements than ``size`` indicates, the return value will be padded with ``fill_value``, + and returned indices will be padded with an out-of-bound index. + fill_value: when ``size`` is specified and there are fewer than the indicated number of + elements, fill the remaining entries ``fill_value``. Defaults to the smallest value + in the xor result. Returns: An array of values that are found in exactly one of the input arrays. @@ -250,22 +299,21 @@ def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Arr Array([1, 2, 5, 6], dtype=int32) """ check_arraylike("setxor1d", ar1, ar2) - ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()") - ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()") + arr1, arr2 = promote_dtypes(ravel(ar1), ravel(ar2)) + del ar1, ar2 - ar1 = ravel(ar1) - ar2 = ravel(ar2) + if size is not None: + return _setxor1d_size(arr1, arr2, fill_value=fill_value, + assume_unique=assume_unique, size=size) if not assume_unique: - ar1 = unique(ar1) - ar2 = unique(ar2) - - aux = concatenate((ar1, ar2)) + arr1 = unique(arr1) + arr2 = unique(arr2) + aux = concatenate((arr1, arr2)) if aux.size == 0: return aux - aux = sort(aux) - flag = concatenate((array([True]), aux[1:] != aux[:-1], array([True]))) + flag = concatenate((True, aux[1:] != aux[:-1], True), axis=None) return aux[flag[1:] & flag[:-1]] @@ -312,7 +360,7 @@ def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, as arr1, ind1, num_unique1 = _unique(arr1, 0, size=arr1.size, return_index=True, return_true_size=True, fill_value=0) arr2, ind2, num_unique2 = _unique(arr2, 0, size=arr2.size, return_index=True, return_true_size=True, fill_value=0) arr = zeros(arr1.size + arr2.size, dtype=dtypes.result_type(arr1, arr2)) - arr = arr.at[:arr1.size].set(arr1) + arr = lax.dynamic_update_slice(arr, arr1, (0,)) arr = lax.dynamic_update_slice(arr, arr2, (num_unique1,)) mask = arange(arr.size) < num_unique1 + num_unique2 _, aux, aux_sort_indices = lax.sort([~mask, arr, arange(arr.size)], is_stable=True, num_keys=2) @@ -326,8 +374,11 @@ def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, as # and vals[num_results:] contains the appropriate fill_value. aux_mask = (aux[1:] == aux[:-1]) & mask[1:] num_results = aux_mask.sum() - val_indices = nonzero(aux_mask, size=size, fill_value=aux.size)[0] - vals = aux.at[val_indices].get(mode='fill', fill_value=0) + if aux.size: + val_indices = nonzero(aux_mask, size=size, fill_value=aux.size)[0] + vals = aux.at[val_indices].get(mode='fill', fill_value=0) + else: + vals = zeros(size, aux.dtype) if fill_value is None: vals = where(arange(len(vals)) < num_results, vals, vals.max()) vals = where(arange(len(vals)) < num_results, vals, vals.min()) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index d5efe1e03f31..860b3358d2ef 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -18,7 +18,7 @@ import collections from collections.abc import Iterator import copy -from functools import partial +from functools import partial, wraps import inspect import io import itertools @@ -174,10 +174,25 @@ def arrays_with_overlapping_values(rng, shapes, dtypes, unique=False, overlap=0. else: vals = jtu.rand_default(rng)((total_size,), 'int32') offsets = [int(sum(sizes[:i]) * (1 - overlap)) for i in range(len(sizes))] - return [np.random.permutation(vals[offset: offset + size]).reshape(shape).astype(dtype) + return [rng.permutation(vals[offset: offset + size]).reshape(shape).astype(dtype) for (offset, size, shape, dtype) in zip(offsets, sizes, shapes, dtypes)] +def with_size_argument(fun): + @wraps(fun) + def wrapped(*args, size=None, fill_value=None, **kwargs): + result = fun(*args, **kwargs) + if size is None or size == len(result): + return result + elif size < len(result): + return result[:size] + else: + if fill_value is None: + fill_value = result.min() if result.size else 0 + return np.pad(result, (0, size - len(result)), constant_values=fill_value) + return wrapped + + class LaxBackedNumpyTests(jtu.JaxTestCase): """Tests for LAX-backed Numpy implementation.""" @@ -786,19 +801,22 @@ def jnp_fun(arg1, arg2): shape1=all_shapes, shape2=all_shapes, assume_unique=[False, True], + size=[None, 2, 5], + fill_value=[None, 99], overlap=[0.1, 0.5, 0.9], ) - def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique, overlap): + def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique, size, fill_value, overlap): args_maker = partial(arrays_with_overlapping_values, self.rng(), shapes=[shape1, shape2], dtypes=[dtype1, dtype2], overlap=overlap) - jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique) + jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique, + size=size, fill_value=fill_value) def np_fun(ar1, ar2): if assume_unique: # numpy requires 1D inputs when assume_unique is True. ar1 = np.ravel(ar1) ar2 = np.ravel(ar2) - return np.setxor1d(ar1, ar2, assume_unique) + return with_size_argument(np.setxor1d)(ar1, ar2, assume_unique, size=size, fill_value=fill_value) with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)