Skip to content

Commit

Permalink
Add comment
Browse files Browse the repository at this point in the history
  • Loading branch information
rajasekharporeddy committed Sep 11, 2024
1 parent c36803d commit d961e13
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Passing a DLPack capsule to {func}`jax.dlpack.from_dlpack` is deprecated.
The argument to {func}`jax.dlpack.from_dlpack` should be an array from
another framework that implements the ``__dlpack__`` protocol.
* Passing `non-arraylike` inputs to {func}`jax.numpy.trim_zeros` is deprecated
and now raises a {obj}`DeprecationWarning`. It currently is converted in to an
* Passing `non-array` inputs to {func}`jax.numpy.trim_zeros` is deprecated and
now raises a {obj}`DeprecationWarning`. It currently is converted in to an
array, and in the future will raise a {obj}`TypeError`. Also, passing `NdArrays`
with `ndim != 1` to {func}`jax.numpy.trim_zeros` is deprecated and now raises
a {obj}`DeprecationWarning`. It's current behavior is inconsistent with
Expand Down
10 changes: 5 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6397,23 +6397,23 @@ def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array:
>>> jnp.trim_zeros(x)
Array([2, 0, 1, 4, 3], dtype=int32)
"""
# Non-array inputs are deprecated 2024-09-11
util.check_arraylike("trim_zeros", filt, emit_warning=True)
core.concrete_or_error(None, filt,
"Error arose in the `filt` argument of trim_zeros()")
filt_arr = jax.numpy.asarray(filt)
del filt
if filt_arr.ndim != 1:
# Added on Sep 11 2024
# Added on 2024-09-11
if deprecations.is_accelerated("jax-numpy-trimzeros-not-1d-array"):
raise TypeError(f"'filt' must be 1-D array, but received {filt_arr.ndim}-D array.")
deprecations.warn(
"jax-numpy-trimzeros-not-1d-array",
warnings.warn(
"Passing arrays with ndim != 1 to jnp.trim_zeros() is deprecated. Currently, it "
"works with Arrays having ndim != 1. In the future this will result in an error.",
stacklevel=2)
DeprecationWarning, stacklevel=2)
nz = (filt_arr == 0)
if reductions.all(nz):
return empty(0, _dtype(filt_arr))
return empty(0, filt_arr.dtype)
start = argmin(nz) if 'f' in trim.lower() else 0
end = argmin(nz[::-1]) if 'b' in trim.lower() else 0
return filt_arr[start:len(filt_arr) - end]
Expand Down
2 changes: 1 addition & 1 deletion tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,7 +1478,7 @@ def testTrimZeros(self, a_shape, dtype, trim):
jnp_fun = lambda arg1: jnp.trim_zeros(arg1, trim)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)

def testTrimZerosNdimArrayInput(self):
def testTrimZerosNotOneDArray(self):
# TODO: make this an error after the deprecation period.
with self.assertWarnsRegex(DeprecationWarning,
r"Passing arrays with ndim != 1 to jnp.trim_zeros\(\)"):
Expand Down

0 comments on commit d961e13

Please sign in to comment.