Skip to content

Commit

Permalink
jnp.setxor1d: add support for static size argument.
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Aug 12, 2024
1 parent 24f5f69 commit b14b197
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 18 deletions.
79 changes: 65 additions & 14 deletions jax/_src/numpy/setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array:
return (ar1_flat[:, None] == ar2_flat[None, :]).any(-1)


def _concat_unique(arr1: Array, arr2: Array) -> tuple[Array, Array]:
"""Utility to concatenate the unique values from two arrays."""
arr1, arr2 = ravel(arr1), ravel(arr2)
arr1, num_unique1 = _unique(arr1, axis=0, size=arr1.size, return_true_size=True)
arr2, num_unique2 = _unique(arr2, axis=0, size=arr2.size, return_true_size=True)
arr = zeros(arr1.size + arr2.size, dtype=dtypes.result_type(arr1, arr2))
arr = lax.dynamic_update_slice(arr, arr1, (0,))
arr = lax.dynamic_update_slice(arr, arr2, (num_unique1,))
return arr, num_unique1 + num_unique2


def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False,
*, size: int | None = None, fill_value: ArrayLike | None = None) -> Array:
"""Compute the set difference of two 1D arrays.
Expand Down Expand Up @@ -220,7 +231,39 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike,
return cast(Array, out)


def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Array:
@partial(jit, static_argnames=['assume_unique', 'size'])
def _setxor1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, *,
assume_unique: bool, size: int, ) -> Array:
# Ensured by caller
assert arr1.ndim == arr2.ndim == 1
assert arr1.dtype == arr2.dtype

if assume_unique:
arr = concatenate([arr1, arr2])
aux = sort(concatenate([arr1, arr2]))
flag = concatenate((bool(aux.size), aux[1:] != aux[:-1], True), axis=None)
else:
arr, num_unique = _concat_unique(arr1, arr2)
mask = arange(arr.size + 1) < num_unique + 1
_, aux = lax.sort([~mask[1:], arr], is_stable=True, num_keys=2)
flag = mask & concatenate((bool(aux.size), aux[1:] != aux[:-1], False),
axis=None).at[num_unique].set(True)
aux_mask = flag[1:] & flag[:-1]
num_results = aux_mask.sum()
if aux.size:
indices = nonzero(aux_mask, size=size, fill_value=len(aux))[0]
vals = aux.at[indices].get(mode='fill', fill_value=0)
else:
vals = zeros(size, aux.dtype)
if fill_value is None:
vals = where(arange(len(vals)) < num_results, vals, vals.max())
return where(arange(len(vals)) < num_results, vals, vals.min())
else:
return where(arange(len(vals)) < num_results, vals, fill_value)


def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *,
size: int | None = None, fill_value: ArrayLike | None = None) -> Array:
"""Compute the set-wise xor of elements in two arrays.
JAX implementation of :func:`numpy.setxor1d`.
Expand All @@ -234,6 +277,12 @@ def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Arr
assume_unique: if True, assume the input arrays contain unique values. This allows
a more efficient implementation, but if ``assume_unique`` is True and the input
arrays contain duplicates, the behavior is undefined. default: False.
size: if specified, return only the first ``size`` sorted elements. If there are fewer
elements than ``size`` indicates, the return value will be padded with ``fill_value``,
and returned indices will be padded with an out-of-bound index.
fill_value: when ``size`` is specified and there are fewer than the indicated number of
elements, fill the remaining entries ``fill_value``. Defaults to the smallest value
in the xor result.
Returns:
An array of values that are found in exactly one of the input arrays.
Expand All @@ -250,22 +299,21 @@ def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Arr
Array([1, 2, 5, 6], dtype=int32)
"""
check_arraylike("setxor1d", ar1, ar2)
ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()")
ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()")
arr1, arr2 = promote_dtypes(ravel(ar1), ravel(ar2))
del ar1, ar2

ar1 = ravel(ar1)
ar2 = ravel(ar2)
if size is not None:
return _setxor1d_size(arr1, arr2, fill_value=fill_value,
assume_unique=assume_unique, size=size)

if not assume_unique:
ar1 = unique(ar1)
ar2 = unique(ar2)

aux = concatenate((ar1, ar2))
arr1 = unique(arr1)
arr2 = unique(arr2)
aux = concatenate((arr1, arr2))
if aux.size == 0:
return aux

aux = sort(aux)
flag = concatenate((array([True]), aux[1:] != aux[:-1], array([True])))
flag = concatenate((True, aux[1:] != aux[:-1], True), axis=None)
return aux[flag[1:] & flag[:-1]]


Expand Down Expand Up @@ -312,7 +360,7 @@ def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, as
arr1, ind1, num_unique1 = _unique(arr1, 0, size=arr1.size, return_index=True, return_true_size=True, fill_value=0)
arr2, ind2, num_unique2 = _unique(arr2, 0, size=arr2.size, return_index=True, return_true_size=True, fill_value=0)
arr = zeros(arr1.size + arr2.size, dtype=dtypes.result_type(arr1, arr2))
arr = arr.at[:arr1.size].set(arr1)
arr = lax.dynamic_update_slice(arr, arr1, (0,))
arr = lax.dynamic_update_slice(arr, arr2, (num_unique1,))
mask = arange(arr.size) < num_unique1 + num_unique2
_, aux, aux_sort_indices = lax.sort([~mask, arr, arange(arr.size)], is_stable=True, num_keys=2)
Expand All @@ -326,8 +374,11 @@ def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, as
# and vals[num_results:] contains the appropriate fill_value.
aux_mask = (aux[1:] == aux[:-1]) & mask[1:]
num_results = aux_mask.sum()
val_indices = nonzero(aux_mask, size=size, fill_value=aux.size)[0]
vals = aux.at[val_indices].get(mode='fill', fill_value=0)
if aux.size:
val_indices = nonzero(aux_mask, size=size, fill_value=aux.size)[0]
vals = aux.at[val_indices].get(mode='fill', fill_value=0)
else:
vals = zeros(size, aux.dtype)
if fill_value is None:
vals = where(arange(len(vals)) < num_results, vals, vals.max())
vals = where(arange(len(vals)) < num_results, vals, vals.min())
Expand Down
26 changes: 22 additions & 4 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import collections
from collections.abc import Iterator
import copy
from functools import partial
from functools import partial, wraps
import inspect
import io
import itertools
Expand Down Expand Up @@ -178,6 +178,21 @@ def arrays_with_overlapping_values(rng, shapes, dtypes, unique=False, overlap=0.
for (offset, size, shape, dtype) in zip(offsets, sizes, shapes, dtypes)]


def with_size_argument(fun):
@wraps(fun)
def wrapped(*args, size=None, fill_value=None, **kwargs):
result = fun(*args, **kwargs)
if size is None or size == len(result):
return result
elif size < len(result):
return result[:size]
else:
if fill_value is None:
fill_value = result.min() if result.size else 0
return np.pad(result, (0, size - len(result)), constant_values=fill_value)
return wrapped


class LaxBackedNumpyTests(jtu.JaxTestCase):
"""Tests for LAX-backed Numpy implementation."""

Expand Down Expand Up @@ -786,19 +801,22 @@ def jnp_fun(arg1, arg2):
shape1=all_shapes,
shape2=all_shapes,
assume_unique=[False, True],
size=[None, 2, 5],
fill_value=[None, 99],
overlap=[0.1, 0.5, 0.9],
)
def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique, overlap):
def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique, size, fill_value, overlap):
args_maker = partial(arrays_with_overlapping_values, self.rng(),
shapes=[shape1, shape2], dtypes=[dtype1, dtype2],
overlap=overlap)
jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique)
jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique,
size=size, fill_value=fill_value)
def np_fun(ar1, ar2):
if assume_unique:
# numpy requires 1D inputs when assume_unique is True.
ar1 = np.ravel(ar1)
ar2 = np.ravel(ar2)
return np.setxor1d(ar1, ar2, assume_unique)
return with_size_argument(np.setxor1d)(ar1, ar2, assume_unique, size=size, fill_value=fill_value)
with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)

Expand Down

0 comments on commit b14b197

Please sign in to comment.