Skip to content

Commit

Permalink
Improve docs for jax.numpy: arcsinh, arccosh and arctanh
Browse files Browse the repository at this point in the history
  • Loading branch information
rajasekharporeddy committed Sep 27, 2024
1 parent 5a1549c commit 526b1a7
Showing 1 changed file with 143 additions and 6 deletions.
149 changes: 143 additions & 6 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ def arcsin(x: ArrayLike, /) -> Array:
Note:
- ``jnp.arcsin`` returns ``nan`` when ``x`` is real-valued and not in the closed
interval ``[-1, 1]``.
- ``jnp.arcsin`` follows the branch cut convention of :func:`numpy.arcsin` for
- ``jnp.arcsin`` follows the branch cut convention of :obj:`numpy.arcsin` for
complex inputs.
See also:
Expand Down Expand Up @@ -645,7 +645,7 @@ def arccos(x: ArrayLike, /) -> Array:
Note:
- ``jnp.arccos`` returns ``nan`` when ``x`` is real-valued and not in the closed
interval ``[-1, 1]``.
- ``jnp.arccos`` follows the branch cut convention of :func:`numpy.arccos` for
- ``jnp.arccos`` follows the branch cut convention of :obj:`numpy.arccos` for
complex inputs.
See also:
Expand Down Expand Up @@ -685,7 +685,7 @@ def arctan(x: ArrayLike, /) -> Array:
in radians in the range ``[-pi/2, pi/2]``, promoting to inexact dtype.
Note:
``jnp.arctan`` follows the branch cut convention of :func:`numpy.arctan` for
``jnp.arctan`` follows the branch cut convention of :obj:`numpy.arctan` for
complex inputs.
See also:
Expand Down Expand Up @@ -817,14 +817,109 @@ def cosh(x: ArrayLike, /) -> Array:
"""
return lax.cosh(*promote_args_inexact('cosh', x))

@implements(np.arcsinh, module='numpy')

@partial(jit, inline=True)
def arcsinh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise inverse of hyperbolic sine of input.
JAX implementation of :obj:`numpy.arcsinh`.
The inverse of hyperbolic sine is defined by:
.. math::
arcsinh(x) = \ln(x + \sqrt{1 + x^2})
Args:
x: input array or scalar.
Returns:
An array of same shape as ``x`` containing the inverse of hyperbolic sine of
each element of ``x``, promoting to inexact dtype.
Note:
``jnp.arcsinh`` follows the branch cut convention of :obj:`numpy.arcsinh` for
complex inputs.
See also:
- :func:`jax.numpy.sinh`: Computes the element-wise hyperbolic sine of the input.
- :func:`jax.numpy.arccosh`: Computes the element-wise inverse of hyperbolic
cosine of the input.
- :func:`jax.numpy.arctanh`: Computes the element-wise inverse of hyperbolic
tangent of the input.
Examples:
>>> x = jnp.array([[-2, 3, 1],
... [4, 9, -5]])
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.arcsinh(x)
Array([[-1.444, 1.818, 0.881],
[ 2.095, 2.893, -2.312]], dtype=float32)
>>> with jnp.printoptions(precision=3, suppress=True):
... -(1/1j) * jnp.arcsin(x / 1j)
Array([[-1.444+0.j, 1.818+0.j, 0.881+0.j],
[ 2.095+0.j, 2.893+0.j, -2.312+0.j]], dtype=complex64, weak_type=True)
For complex-valued inputs:
>>> x1 = jnp.array([4-3j, 2j])
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.arcsinh(x1)
Array([2.306-0.634j, 1.317+1.571j], dtype=complex64)
>>> with jnp.printoptions(precision=3, suppress=True):
... -(1/1j) * jnp.arcsin(x1 / 1j)
Array([ 2.306-0.634j, -1.317+1.571j], dtype=complex64)
"""
return lax.asinh(*promote_args_inexact('arcsinh', x))

@implements(np.arccosh, module='numpy')

@jit
def arccosh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise inverse of hyperbolic cosine of input.
JAX implementation of :obj:`numpy.arccosh`.
The inverse of hyperbolic cosine is defined by:
.. math::
arccosh(x) = \ln(x + \sqrt{x^2 - 1})
Args:
x: input array or scalar.
Returns:
An array of same shape as ``x`` containing the inverse of hyperbolic cosine
of each element of ``x``, promoting to inexact dtype.
Note:
- ``jnp.arccosh`` returns ``nan`` for real-values in the range ``[-inf, 1)``.
- ``jnp.arccosh`` follows the branch cut convention of :obj:`numpy.arccosh` for
complex inputs.
See also:
- :func:`jax.numpy.cosh`: Computes the element-wise hyperbolic cosine of the
input.
- :func:`jax.numpy.arcsinh`: Computes the element-wise inverse of hyperbolic
sine of the input.
- :func:`jax.numpy.arctanh`: Computes the element-wise inverse of hyperbolic
tangent of the input.
Examples:
>>> x = jnp.array([[1, 3, -4],
... [-5, 2, 7]])
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.arccosh(x)
Array([[0. , 1.763, nan],
[ nan, 1.317, 2.634]], dtype=float32)
For complex-valued input:
>>> x1 = jnp.array([-jnp.inf+0j, 1+2j, -5+0j])
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.arccosh(x1)
Array([ inf+3.142j, 1.529+1.144j, 2.292+3.142j], dtype=complex64)
"""
# Note: arccosh is multi-valued for complex input, and lax.acosh
# uses a different convention than np.arccosh.
result = lax.acosh(*promote_args_inexact("arccosh", x))
Expand Down Expand Up @@ -885,9 +980,51 @@ def tanh(x: ArrayLike, /) -> Array:
"""
return lax.tanh(*promote_args_inexact('tanh', x))

@implements(np.arctanh, module='numpy')

@partial(jit, inline=True)
def arctanh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise inverse of hyperbolic tangent of input.
JAX implementation of :obj:`numpy.arctanh`.
The inverse of hyperbolic tangent is defined by:
.. math::
arctanh(x) = \frac{1}{2} [\ln(1 + x) - \ln(1 - x)]
Args:
x: input array or scalar.
Returns:
An array of same shape as ``x`` containing the inverse of hyperbolic tangent
of each element of ``x``, promoting to inexact dtype.
Note:
- ``jnp.arctanh`` returns ``nan`` for real-values outside the range ``[-1, 1]``.
- ``jnp.arctanh`` follows the branch cut convention of :obj:`numpy.arctanh` for
complex inputs.
See also:
- :func:`jax.numpy.tanh`: Computes the element-wise hyperbolic tangent of the
input.
- :func:`jax.numpy.arcsinh`: Computes the element-wise inverse of hyperbolic
sine of the input.
- :func:`jax.numpy.arccosh`: Computes the element-wise inverse of hyperbolic
cosine of the input.
Examples:
>>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2])
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.arctanh(x)
Array([ nan, -inf, -0.549, 0. , 0.549, inf, nan], dtype=float32)
For complex-valued input:
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.arctanh(x1)
Array([-0.549+1.571j, 0.347+1.571j, 0.239-1.509j], dtype=complex64)
"""
return lax.atanh(*promote_args_inexact('arctanh', x))


Expand Down

0 comments on commit 526b1a7

Please sign in to comment.