From 6072f97961c7535c8aed2b04f382b8ad33382444 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Thu, 26 Sep 2024 21:38:14 +0530 Subject: [PATCH] Raise ValueError when axis1==axis2 for jnp.trace --- jax/_src/numpy/lax_numpy.py | 4 ++++ tests/lax_numpy_test.py | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ca70ff35dd9f..70fea964c321 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -6702,6 +6702,10 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int util.check_arraylike("trace", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.trace is not supported.") + + if _canonicalize_axis(axis1, ndim(a)) == _canonicalize_axis(axis2, ndim(a)): + raise ValueError(f"axis1 and axis2 can not be same. axis1={axis1} and axis2={axis2}") + dtypes.check_user_dtype_supported(dtype, "trace") a_shape = shape(a) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index c6d56885a6a8..6f8167df9c29 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2725,6 +2725,11 @@ def np_fun(arg): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + def testTraceSameAxesError(self): + a = jnp.arange(1, 13).reshape(2, 3, 2) + with self.assertRaisesRegex(ValueError, r"axis1 and axis2 can not be same"): + jnp.trace(a, axis1=1, axis2=-2) + @jtu.sample_product( ashape=[(15,), (16,), (17,)], vshape=[(), (5,), (5, 5)],