diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ab53d9e736dd..387b3b2a51a7 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6812,6 +6812,10 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int util.check_arraylike("trace", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.trace is not supported.") + + if _canonicalize_axis(axis1, ndim(a)) == _canonicalize_axis(axis2, ndim(a)): + raise ValueError(f"axis1 and axis2 can not be same. axis1={axis1} and axis2={axis2}") + dtypes.check_user_dtype_supported(dtype, "trace") a_shape = shape(a) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index c6d56885a6a8..6f8167df9c29 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2725,6 +2725,11 @@ def np_fun(arg): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + def testTraceSameAxesError(self): + a = jnp.arange(1, 13).reshape(2, 3, 2) + with self.assertRaisesRegex(ValueError, r"axis1 and axis2 can not be same"): + jnp.trace(a, axis1=1, axis2=-2) + @jtu.sample_product( ashape=[(15,), (16,), (17,)], vshape=[(), (5,), (5, 5)],