From d961e13733bd4032de1065a1d7699a0ba6d848d0 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Wed, 11 Sep 2024 22:51:30 +0530 Subject: [PATCH] Add comment --- CHANGELOG.md | 4 ++-- jax/_src/numpy/lax_numpy.py | 10 +++++----- tests/lax_numpy_test.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c45f7c972b1..31393b31441e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -64,8 +64,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Passing a DLPack capsule to {func}`jax.dlpack.from_dlpack` is deprecated. The argument to {func}`jax.dlpack.from_dlpack` should be an array from another framework that implements the ``__dlpack__`` protocol. - * Passing `non-arraylike` inputs to {func}`jax.numpy.trim_zeros` is deprecated - and now raises a {obj}`DeprecationWarning`. It currently is converted in to an + * Passing `non-array` inputs to {func}`jax.numpy.trim_zeros` is deprecated and + now raises a {obj}`DeprecationWarning`. It currently is converted in to an array, and in the future will raise a {obj}`TypeError`. Also, passing `NdArrays` with `ndim != 1` to {func}`jax.numpy.trim_zeros` is deprecated and now raises a {obj}`DeprecationWarning`. It's current behavior is inconsistent with diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index a6f0ba680552..f2603406441c 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6397,23 +6397,23 @@ def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array: >>> jnp.trim_zeros(x) Array([2, 0, 1, 4, 3], dtype=int32) """ + # Non-array inputs are deprecated 2024-09-11 util.check_arraylike("trim_zeros", filt, emit_warning=True) core.concrete_or_error(None, filt, "Error arose in the `filt` argument of trim_zeros()") filt_arr = jax.numpy.asarray(filt) del filt if filt_arr.ndim != 1: - # Added on Sep 11 2024 + # Added on 2024-09-11 if deprecations.is_accelerated("jax-numpy-trimzeros-not-1d-array"): raise TypeError(f"'filt' must be 1-D array, but received {filt_arr.ndim}-D array.") - deprecations.warn( - "jax-numpy-trimzeros-not-1d-array", + warnings.warn( "Passing arrays with ndim != 1 to jnp.trim_zeros() is deprecated. Currently, it " "works with Arrays having ndim != 1. In the future this will result in an error.", - stacklevel=2) + DeprecationWarning, stacklevel=2) nz = (filt_arr == 0) if reductions.all(nz): - return empty(0, _dtype(filt_arr)) + return empty(0, filt_arr.dtype) start = argmin(nz) if 'f' in trim.lower() else 0 end = argmin(nz[::-1]) if 'b' in trim.lower() else 0 return filt_arr[start:len(filt_arr) - end] diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 9bb0c49563a2..484968b78802 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1478,7 +1478,7 @@ def testTrimZeros(self, a_shape, dtype, trim): jnp_fun = lambda arg1: jnp.trim_zeros(arg1, trim) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) - def testTrimZerosNdimArrayInput(self): + def testTrimZerosNotOneDArray(self): # TODO: make this an error after the deprecation period. with self.assertWarnsRegex(DeprecationWarning, r"Passing arrays with ndim != 1 to jnp.trim_zeros\(\)"):