Skip to content

Commit

Permalink
Merge pull request #9184 from jakevdp:unique-nan
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 422287302
  • Loading branch information
jax authors committed Jan 17, 2022
2 parents bebe984 + bd157cf commit 6411f8a
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
Previously negative ``NaN`` values were sorted to the front of the array, and ``NaN``
values with different internal bit representations were not treated as equivalent, and
were sorted according to those bit patterns ({jax-issue}`#9178`).
* {func}`jax.numpy.unique` now treats ``NaN`` values in the same way as `np.unique` in
NumPy versions 1.21 and newer: at most one ``NaN`` value will appear in the uniquified
output ({jax-issue}`9184`).

* Bug fixes:
* host_callback now supports ad_checkpoint.checkpoint ({jax-issue}`#8907`).
Expand Down
13 changes: 12 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5602,6 +5602,11 @@ def replace(tup, val):
@partial(jit, static_argnums=1)
def _unique_sorted_mask(ar, axis):
aux = moveaxis(ar, axis, 0)
if issubdtype(aux.dtype, np.complexfloating):
# Work around issue in sorting of complex numbers with Nan only in the
# imaginary component. This can be removed if sorting in this situation
# is fixed to match numpy.
aux = where(isnan(aux), lax._const(aux, nan), aux)
size, *out_shape = aux.shape
if _prod(out_shape) == 0:
size = 1
Expand All @@ -5610,7 +5615,13 @@ def _unique_sorted_mask(ar, axis):
perm = lexsort(aux.reshape(size, _prod(out_shape)).T[::-1])
aux = aux[perm]
if aux.size:
mask = ones(size, dtype=bool).at[1:].set(any(aux[1:] != aux[:-1], tuple(range(1, aux.ndim))))
if issubdtype(aux.dtype, inexact):
# This is appropriate for both float and complex due to the documented behavior of np.unique:
# See https://github.com/numpy/numpy/blob/v1.22.0/numpy/lib/arraysetops.py#L212-L220
neq = lambda x, y: lax.ne(x, y) & ~(isnan(x) & isnan(y))
else:
neq = lax.ne
mask = ones(size, dtype=bool).at[1:].set(any(neq(aux[1:], aux[:-1]), tuple(range(1, aux.ndim))))
else:
mask = zeros(size, dtype=bool)
return aux, mask, perm
Expand Down
23 changes: 23 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2479,6 +2479,29 @@ def np_fun(x, fill_value=fill_value):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@unittest.skipIf(numpy_version < (1, 21), "Numpy < 1.21 does not properly handle NaN values in unique.")
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{dtype.__name__}", "dtype": dtype}
for dtype in inexact_dtypes))
def testUniqueNans(self, dtype):
def args_maker():
x = [-0.0, 0.0, 1.0, 1.0, np.nan, -np.nan]
if np.issubdtype(dtype, np.complexfloating):
x = [complex(i, j) for i, j in itertools.product(x, repeat=2)]
return [np.array(x, dtype=dtype)]

kwds = dict(return_index=True, return_inverse=True, return_counts=True)
jnp_fun = partial(jnp.unique, **kwds)
def np_fun(x):
dtype = x.dtype
# numpy unique fails for bfloat16 NaNs, so we cast to float64
if x.dtype == jnp.bfloat16:
x = x.astype('float64')
u, *rest = np.unique(x, **kwds)
return (u.astype(dtype), *rest)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)


@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_fixed_size={}".format(fixed_size),
"fixed_size": fixed_size}
Expand Down

0 comments on commit 6411f8a

Please sign in to comment.