From c17ae0fd985ddd3ef27674086d616b4b7d7a3b67 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Fri, 27 Sep 2024 23:03:11 +0530 Subject: [PATCH] Improve docs for jax.numpy: arcsinh, arccosh and arctanh --- jax/_src/numpy/ufuncs.py | 144 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 138 insertions(+), 6 deletions(-) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index dc265b8e87e1..b2b7f5e710c3 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -604,7 +604,7 @@ def arcsin(x: ArrayLike, /) -> Array: Note: - ``jnp.arcsin`` returns ``nan`` when ``x`` is real-valued and not in the closed interval ``[-1, 1]``. - - ``jnp.arcsin`` follows the branch cut convention of :func:`numpy.arcsin` for + - ``jnp.arcsin`` follows the branch cut convention of :obj:`numpy.arcsin` for complex inputs. See also: @@ -645,7 +645,7 @@ def arccos(x: ArrayLike, /) -> Array: Note: - ``jnp.arccos`` returns ``nan`` when ``x`` is real-valued and not in the closed interval ``[-1, 1]``. - - ``jnp.arccos`` follows the branch cut convention of :func:`numpy.arccos` for + - ``jnp.arccos`` follows the branch cut convention of :obj:`numpy.arccos` for complex inputs. See also: @@ -685,7 +685,7 @@ def arctan(x: ArrayLike, /) -> Array: in radians in the range ``[-pi/2, pi/2]``, promoting to inexact dtype. Note: - ``jnp.arctan`` follows the branch cut convention of :func:`numpy.arctan` for + ``jnp.arctan`` follows the branch cut convention of :obj:`numpy.arctan` for complex inputs. See also: @@ -817,14 +817,103 @@ def cosh(x: ArrayLike, /) -> Array: """ return lax.cosh(*promote_args_inexact('cosh', x)) -@implements(np.arcsinh, module='numpy') + @partial(jit, inline=True) def arcsinh(x: ArrayLike, /) -> Array: + r"""Calculate element-wise inverse of hyperbolic sine of input. + + JAX implementation of :obj:`numpy.arcsinh`. + + The inverse of hyperbolic sine is defined by: + + .. math:: + + arcsinh(x) = \ln(x + \sqrt{1 + x^2}) + + Args: + x: input array or scalar. + + Returns: + An array of same shape as ``x`` containing the inverse of hyperbolic sine of + each element of ``x``, promoting to inexact dtype. + + Note: + - ``jnp.arcsinh`` returns ``nan`` for values outside the range ``(-inf, inf)``. + - ``jnp.arcsinh`` follows the branch cut convention of :obj:`numpy.arcsinh` + for complex inputs. + + See also: + - :func:`jax.numpy.sinh`: Computes the element-wise hyperbolic sine of the input. + - :func:`jax.numpy.arccosh`: Computes the element-wise inverse of hyperbolic + cosine of the input. + - :func:`jax.numpy.arctanh`: Computes the element-wise inverse of hyperbolic + tangent of the input. + + Examples: + >>> x = jnp.array([[-2, 3, 1], + ... [4, 9, -5]]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arcsinh(x) + Array([[-1.444, 1.818, 0.881], + [ 2.095, 2.893, -2.312]], dtype=float32) + + For complex-valued inputs: + + >>> x1 = jnp.array([4-3j, 2j]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arcsinh(x1) + Array([2.306-0.634j, 1.317+1.571j], dtype=complex64) + """ return lax.asinh(*promote_args_inexact('arcsinh', x)) -@implements(np.arccosh, module='numpy') + @jit def arccosh(x: ArrayLike, /) -> Array: + r"""Calculate element-wise inverse of hyperbolic cosine of input. + + JAX implementation of :obj:`numpy.arccosh`. + + The inverse of hyperbolic cosine is defined by: + + .. math:: + + arccosh(x) = \ln(x + \sqrt{x^2 - 1}) + + Args: + x: input array or scalar. + + Returns: + An array of same shape as ``x`` containing the inverse of hyperbolic cosine + of each element of ``x``, promoting to inexact dtype. + + Note: + - ``jnp.arccosh`` returns ``nan`` for real-values in the range ``[-inf, 1)``. + - ``jnp.arccosh`` follows the branch cut convention of :obj:`numpy.arccosh` + for complex inputs. + + See also: + - :func:`jax.numpy.cosh`: Computes the element-wise hyperbolic cosine of the + input. + - :func:`jax.numpy.arcsinh`: Computes the element-wise inverse of hyperbolic + sine of the input. + - :func:`jax.numpy.arctanh`: Computes the element-wise inverse of hyperbolic + tangent of the input. + + Examples: + >>> x = jnp.array([[1, 3, -4], + ... [-5, 2, 7]]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arccosh(x) + Array([[0. , 1.763, nan], + [ nan, 1.317, 2.634]], dtype=float32) + + For complex-valued input: + + >>> x1 = jnp.array([-jnp.inf+0j, 1+2j, -5+0j]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arccosh(x1) + Array([ inf+3.142j, 1.529+1.144j, 2.292+3.142j], dtype=complex64) + """ # Note: arccosh is multi-valued for complex input, and lax.acosh # uses a different convention than np.arccosh. result = lax.acosh(*promote_args_inexact("arccosh", x)) @@ -885,9 +974,52 @@ def tanh(x: ArrayLike, /) -> Array: """ return lax.tanh(*promote_args_inexact('tanh', x)) -@implements(np.arctanh, module='numpy') + @partial(jit, inline=True) def arctanh(x: ArrayLike, /) -> Array: + r"""Calculate element-wise inverse of hyperbolic tangent of input. + + JAX implementation of :obj:`numpy.arctanh`. + + The inverse of hyperbolic tangent is defined by: + + .. math:: + + arctanh(x) = \frac{1}{2} [\ln(1 + x) - \ln(1 - x)] + + Args: + x: input array or scalar. + + Returns: + An array of same shape as ``x`` containing the inverse of hyperbolic tangent + of each element of ``x``, promoting to inexact dtype. + + Note: + - ``jnp.arctanh`` returns ``nan`` for real-values outside the range ``[-1, 1]``. + - ``jnp.arctanh`` follows the branch cut convention of :obj:`numpy.arctanh` + for complex inputs. + + See also: + - :func:`jax.numpy.tanh`: Computes the element-wise hyperbolic tangent of the + input. + - :func:`jax.numpy.arcsinh`: Computes the element-wise inverse of hyperbolic + sine of the input. + - :func:`jax.numpy.arccosh`: Computes the element-wise inverse of hyperbolic + cosine of the input. + + Examples: + >>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arctanh(x) + Array([ nan, -inf, -0.549, 0. , 0.549, inf, nan], dtype=float32) + + For complex-valued input: + + >>> x1 = jnp.array([-2+0j, 3+0j, 4-1j]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arctanh(x1) + Array([-0.549+1.571j, 0.347+1.571j, 0.239-1.509j], dtype=complex64) + """ return lax.atanh(*promote_args_inexact('arctanh', x))