diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 692bd3d49aaa..ae7ff06cc5c2 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -9166,9 +9166,43 @@ def _gcd_body_fn(xs: tuple[Array, Array]) -> tuple[Array, Array]: where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0))) return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2)) -@util.implements(np.gcd, module='numpy') @jit def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: + """Compute the greatest common divisor of two arrays. + + JAX implementation of :func:`numpy.gcd`. + + Args: + x1: First input array. The elements must have integer dtype. + x2: Second input array. The elements must have integer dtype. + + Returns: + An array containing the greatest common divisors of the corresponding + elements from the absolute values of `x1` and `x2`. + + See also: + - :func:`jax.numpy.lcm`: compute the least common multiple of two arrays. + + Examples: + Scalar inputs: + + >>> jnp.gcd(12, 18) + Array(6, dtype=int32, weak_type=True) + + Array inputs: + + >>> x1 = jnp.array([12, 18, 24]) + >>> x2 = jnp.array([5, 10, 15]) + >>> jnp.gcd(x1, x2) + Array([1, 2, 3], dtype=int32) + + Broadcasting: + + >>> x1 = jnp.array([12]) + >>> x2 = jnp.array([6, 9, 12]) + >>> jnp.gcd(x1, x2) + Array([ 6, 3, 12], dtype=int32) + """ util.check_arraylike("gcd", x1, x2) x1, x2 = util.promote_dtypes(x1, x2) if not issubdtype(_dtype(x1), integer): @@ -9178,9 +9212,43 @@ def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: return gcd -@util.implements(np.lcm, module='numpy') @jit def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: + """Compute the least common multiple of two arrays. + + JAX implementation of :func:`numpy.lcm`. + + Args: + x1: First input array. The elements must have integer dtype. + x2: Second input array. The elements must have integer dtype. + + Returns: + An array containing the least common multiple of the corresponding + elements from the absolute values of `x1` and `x2`. + + See also: + - :func:`jax.numpy.gcd`: compute the greatest common divisor of two arrays. + + Examples: + Scalar inputs: + + >>> jnp.lcm(12, 18) + Array(36, dtype=int32, weak_type=True) + + Array inputs: + + >>> x1 = jnp.array([12, 18, 24]) + >>> x2 = jnp.array([5, 10, 15]) + >>> jnp.lcm(x1, x2) + Array([ 60, 90, 120], dtype=int32) + + Broadcasting: + + >>> x1 = jnp.array([12]) + >>> x2 = jnp.array([6, 9, 12]) + >>> jnp.lcm(x1, x2) + Array([12, 36, 12], dtype=int32) + """ util.check_arraylike("lcm", x1, x2) x1, x2 = util.promote_dtypes(x1, x2) x1, x2 = ufuncs.abs(x1), ufuncs.abs(x2)