Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jnp.setxor1d: add support for static size argument #23020

Merged
merged 1 commit into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 65 additions & 14 deletions jax/_src/numpy/setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array:
return (ar1_flat[:, None] == ar2_flat[None, :]).any(-1)


def _concat_unique(arr1: Array, arr2: Array) -> tuple[Array, Array]:
"""Utility to concatenate the unique values from two arrays."""
arr1, arr2 = ravel(arr1), ravel(arr2)
arr1, num_unique1 = _unique(arr1, axis=0, size=arr1.size, return_true_size=True)
arr2, num_unique2 = _unique(arr2, axis=0, size=arr2.size, return_true_size=True)
arr = zeros(arr1.size + arr2.size, dtype=dtypes.result_type(arr1, arr2))
arr = lax.dynamic_update_slice(arr, arr1, (0,))
arr = lax.dynamic_update_slice(arr, arr2, (num_unique1,))
return arr, num_unique1 + num_unique2


def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
*, size: int | None = None, fill_value: ArrayLike | None = None) -> Array:
"""Compute the set difference of two 1D arrays.
Expand Down Expand Up @@ -220,7 +231,39 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike,
return cast(Array, out)


def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Array:
@partial(jit, static_argnames=['assume_unique', 'size'])
def _setxor1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, *,
assume_unique: bool, size: int, ) -> Array:
# Ensured by caller
assert arr1.ndim == arr2.ndim == 1
assert arr1.dtype == arr2.dtype

if assume_unique:
arr = concatenate([arr1, arr2])
aux = sort(concatenate([arr1, arr2]))
flag = concatenate((bool(aux.size), aux[1:] != aux[:-1], True), axis=None)
else:
arr, num_unique = _concat_unique(arr1, arr2)
mask = arange(arr.size + 1) < num_unique + 1
_, aux = lax.sort([~mask[1:], arr], is_stable=True, num_keys=2)
flag = mask & concatenate((bool(aux.size), aux[1:] != aux[:-1], False),
axis=None).at[num_unique].set(True)
aux_mask = flag[1:] & flag[:-1]
num_results = aux_mask.sum()
if aux.size:
indices = nonzero(aux_mask, size=size, fill_value=len(aux))[0]
vals = aux.at[indices].get(mode='fill', fill_value=0)
else:
vals = zeros(size, aux.dtype)
if fill_value is None:
vals = where(arange(len(vals)) < num_results, vals, vals.max())
return where(arange(len(vals)) < num_results, vals, vals.min())
else:
return where(arange(len(vals)) < num_results, vals, fill_value)


def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *,
size: int | None = None, fill_value: ArrayLike | None = None) -> Array:
"""Compute the set-wise xor of elements in two arrays.

JAX implementation of :func:`numpy.setxor1d`.
Expand All @@ -234,6 +277,12 @@ def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Arr
assume_unique: if True, assume the input arrays contain unique values. This allows
a more efficient implementation, but if ``assume_unique`` is True and the input
arrays contain duplicates, the behavior is undefined. default: False.
size: if specified, return only the first ``size`` sorted elements. If there are fewer
elements than ``size`` indicates, the return value will be padded with ``fill_value``,
and returned indices will be padded with an out-of-bound index.
fill_value: when ``size`` is specified and there are fewer than the indicated number of
elements, fill the remaining entries ``fill_value``. Defaults to the smallest value
in the xor result.

Returns:
An array of values that are found in exactly one of the input arrays.
Expand All @@ -250,22 +299,21 @@ def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Arr
Array([1, 2, 5, 6], dtype=int32)
"""
check_arraylike("setxor1d", ar1, ar2)
ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()")
ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()")
arr1, arr2 = promote_dtypes(ravel(ar1), ravel(ar2))
del ar1, ar2

ar1 = ravel(ar1)
ar2 = ravel(ar2)
if size is not None:
return _setxor1d_size(arr1, arr2, fill_value=fill_value,
assume_unique=assume_unique, size=size)

if not assume_unique:
ar1 = unique(ar1)
ar2 = unique(ar2)

aux = concatenate((ar1, ar2))
arr1 = unique(arr1)
arr2 = unique(arr2)
aux = concatenate((arr1, arr2))
if aux.size == 0:
return aux

aux = sort(aux)
flag = concatenate((array([True]), aux[1:] != aux[:-1], array([True])))
flag = concatenate((True, aux[1:] != aux[:-1], True), axis=None)
return aux[flag[1:] & flag[:-1]]


Expand Down Expand Up @@ -312,7 +360,7 @@ def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, as
arr1, ind1, num_unique1 = _unique(arr1, 0, size=arr1.size, return_index=True, return_true_size=True, fill_value=0)
arr2, ind2, num_unique2 = _unique(arr2, 0, size=arr2.size, return_index=True, return_true_size=True, fill_value=0)
arr = zeros(arr1.size + arr2.size, dtype=dtypes.result_type(arr1, arr2))
arr = arr.at[:arr1.size].set(arr1)
arr = lax.dynamic_update_slice(arr, arr1, (0,))
arr = lax.dynamic_update_slice(arr, arr2, (num_unique1,))
mask = arange(arr.size) < num_unique1 + num_unique2
_, aux, aux_sort_indices = lax.sort([~mask, arr, arange(arr.size)], is_stable=True, num_keys=2)
Expand All @@ -326,8 +374,11 @@ def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, as
# and vals[num_results:] contains the appropriate fill_value.
aux_mask = (aux[1:] == aux[:-1]) & mask[1:]
num_results = aux_mask.sum()
val_indices = nonzero(aux_mask, size=size, fill_value=aux.size)[0]
vals = aux.at[val_indices].get(mode='fill', fill_value=0)
if aux.size:
val_indices = nonzero(aux_mask, size=size, fill_value=aux.size)[0]
vals = aux.at[val_indices].get(mode='fill', fill_value=0)
else:
vals = zeros(size, aux.dtype)
if fill_value is None:
vals = where(arange(len(vals)) < num_results, vals, vals.max())
vals = where(arange(len(vals)) < num_results, vals, vals.min())
Expand Down
28 changes: 23 additions & 5 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import collections
from collections.abc import Iterator
import copy
from functools import partial
from functools import partial, wraps
import inspect
import io
import itertools
Expand Down Expand Up @@ -174,10 +174,25 @@ def arrays_with_overlapping_values(rng, shapes, dtypes, unique=False, overlap=0.
else:
vals = jtu.rand_default(rng)((total_size,), 'int32')
offsets = [int(sum(sizes[:i]) * (1 - overlap)) for i in range(len(sizes))]
return [np.random.permutation(vals[offset: offset + size]).reshape(shape).astype(dtype)
return [rng.permutation(vals[offset: offset + size]).reshape(shape).astype(dtype)
for (offset, size, shape, dtype) in zip(offsets, sizes, shapes, dtypes)]


def with_size_argument(fun):
@wraps(fun)
def wrapped(*args, size=None, fill_value=None, **kwargs):
result = fun(*args, **kwargs)
if size is None or size == len(result):
return result
elif size < len(result):
return result[:size]
else:
if fill_value is None:
fill_value = result.min() if result.size else 0
return np.pad(result, (0, size - len(result)), constant_values=fill_value)
return wrapped


class LaxBackedNumpyTests(jtu.JaxTestCase):
"""Tests for LAX-backed Numpy implementation."""

Expand Down Expand Up @@ -786,19 +801,22 @@ def jnp_fun(arg1, arg2):
shape1=all_shapes,
shape2=all_shapes,
assume_unique=[False, True],
size=[None, 2, 5],
fill_value=[None, 99],
overlap=[0.1, 0.5, 0.9],
)
def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique, overlap):
def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique, size, fill_value, overlap):
args_maker = partial(arrays_with_overlapping_values, self.rng(),
shapes=[shape1, shape2], dtypes=[dtype1, dtype2],
overlap=overlap)
jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique)
jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique,
size=size, fill_value=fill_value)
def np_fun(ar1, ar2):
if assume_unique:
# numpy requires 1D inputs when assume_unique is True.
ar1 = np.ravel(ar1)
ar2 = np.ravel(ar2)
return np.setxor1d(ar1, ar2, assume_unique)
return with_size_argument(np.setxor1d)(ar1, ar2, assume_unique, size=size, fill_value=fill_value)
with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)

Expand Down