Skip to content

Commit

Permalink
gcd_lcm_docstring_added
Browse files Browse the repository at this point in the history
description_improved
  • Loading branch information
selamw1 committed Aug 27, 2024
1 parent 6ee4136 commit 76583c8
Showing 1 changed file with 70 additions and 2 deletions.
72 changes: 70 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9166,9 +9166,43 @@ def _gcd_body_fn(xs: tuple[Array, Array]) -> tuple[Array, Array]:
where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0)))
return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2))

@util.implements(np.gcd, module='numpy')
@jit
def gcd(x1: ArrayLike, x2: ArrayLike) -> Array:
"""Compute the greatest common divisor of two arrays.
JAX implementation of :func:`numpy.gcd`.
Args:
x1: First input array. The elements must have integer dtype.
x2: Second input array. The elements must have integer dtype.
Returns:
An array containing the greatest common divisors of the corresponding
elements from the absolute values of `x1` and `x2`.
See also:
- :func:`jax.numpy.lcm`: compute the least common multiple of two arrays.
Examples:
Scalar inputs:
>>> jnp.gcd(12, 18)
Array(6, dtype=int32, weak_type=True)
Array inputs:
>>> x1 = jnp.array([12, 18, 24])
>>> x2 = jnp.array([5, 10, 15])
>>> jnp.gcd(x1, x2)
Array([1, 2, 3], dtype=int32)
Broadcasting:
>>> x1 = jnp.array([12])
>>> x2 = jnp.array([6, 9, 12])
>>> jnp.gcd(x1, x2)
Array([ 6, 3, 12], dtype=int32)
"""
util.check_arraylike("gcd", x1, x2)
x1, x2 = util.promote_dtypes(x1, x2)
if not issubdtype(_dtype(x1), integer):
Expand All @@ -9178,9 +9212,43 @@ def gcd(x1: ArrayLike, x2: ArrayLike) -> Array:
return gcd


@util.implements(np.lcm, module='numpy')
@jit
def lcm(x1: ArrayLike, x2: ArrayLike) -> Array:
"""Compute the least common multiple of two arrays.
JAX implementation of :func:`numpy.lcm`.
Args:
x1: First input array. The elements must have integer dtype.
x2: Second input array. The elements must have integer dtype.
Returns:
An array containing the least common multiple of the corresponding
elements from the absolute values of `x1` and `x2`.
See also:
- :func:`jax.numpy.gcd`: compute the greatest common divisor of two arrays.
Examples:
Scalar inputs:
>>> jnp.lcm(12, 18)
Array(36, dtype=int32, weak_type=True)
Array inputs:
>>> x1 = jnp.array([12, 18, 24])
>>> x2 = jnp.array([5, 10, 15])
>>> jnp.lcm(x1, x2)
Array([ 60, 90, 120], dtype=int32)
Broadcasting:
>>> x1 = jnp.array([12])
>>> x2 = jnp.array([6, 9, 12])
>>> jnp.lcm(x1, x2)
Array([12, 36, 12], dtype=int32)
"""
util.check_arraylike("lcm", x1, x2)
x1, x2 = util.promote_dtypes(x1, x2)
x1, x2 = ufuncs.abs(x1), ufuncs.abs(x2)
Expand Down

0 comments on commit 76583c8

Please sign in to comment.