Skip to content

Commit

Permalink
Raise ValueError when axis1==axis2 for jnp.trace
Browse files Browse the repository at this point in the history
  • Loading branch information
rajasekharporeddy committed Sep 26, 2024
1 parent f6fdfb4 commit cdb8b33
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
6 changes: 6 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6702,6 +6702,12 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int
util.check_arraylike("trace", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")

axis1 = _canonicalize_axis(axis1, a.ndim)
axis2 = _canonicalize_axis(axis2, a.ndim)
if axis1 == axis2:
raise ValueError(f"axis1 and axis2 can not be same. axis1={axis1} and axis2={axis2}")

dtypes.check_user_dtype_supported(dtype, "trace")

a_shape = shape(a)
Expand Down
17 changes: 17 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2725,6 +2725,23 @@ def np_fun(arg):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
[dict(shape=shape, axis1=axis1, axis2=axis2)
for shape in [shape for shape in all_shapes if len(shape) >= 2]
for axis1 in range(-len(shape), len(shape))
for axis2 in range(-len(shape), len(shape))
if (axis1 % len(shape)) == (axis2 % len(shape))
],
dtype=default_dtypes,
out_dtype=[None] + number_dtypes,
offset=list(range(-4, 4)),
)
def testTraceSameAxesError(self, shape, dtype, out_dtype, offset, axis1, axis2):
rng = jtu.rand_default(self.rng())
arg = rng(shape, dtype)
with self.assertRaisesRegex(ValueError, r"axis1 and axis2 can not be same"):
jnp.trace(arg, offset, axis1, axis2, out_dtype)

@jtu.sample_product(
ashape=[(15,), (16,), (17,)],
vshape=[(), (5,), (5, 5)],
Expand Down

0 comments on commit cdb8b33

Please sign in to comment.