diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index fd0a9e502b1d..dc265b8e87e1 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -710,14 +710,111 @@ def arctan(x: ArrayLike, /) -> Array: """ return lax.atan(*promote_args_inexact('arctan', x)) -@implements(np.sinh, module='numpy') + @partial(jit, inline=True) def sinh(x: ArrayLike, /) -> Array: + r"""Calculate element-wise hyperbolic sine of input. + + JAX implementation of :obj:`numpy.sinh`. + + The hyperbolic sine is defined by: + + .. math:: + + sinh(x) = \frac{e^x - e^{-x}}{2} + + Args: + x: input array or scalar. + + Returns: + An array containing the hyperbolic sine of each element of ``x``, promoting + to inexact dtype. + + Note: + ``jnp.sinh`` is equivalent to computing ``-1j * jnp.sin(1j * x)``. + + See also: + - :func:`jax.numpy.cosh`: Computes the element-wise hyperbolic cosine of the + input. + - :func:`jax.numpy.tanh`: Computes the element-wise hyperbolic tangent of the + input. + - :func:`jax.numpy.arcsinh`: Computes the element-wise inverse of hyperbolic + sine of the input. + + Examples: + >>> x = jnp.array([[-2, 3, 5], + ... [0, -1, 4]]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.sinh(x) + Array([[-3.627, 10.018, 74.203], + [ 0. , -1.175, 27.29 ]], dtype=float32) + >>> with jnp.printoptions(precision=3, suppress=True): + ... -1j * jnp.sin(1j * x) + Array([[-3.627+0.j, 10.018-0.j, 74.203-0.j], + [ 0. -0.j, -1.175+0.j, 27.29 -0.j]], dtype=complex64, weak_type=True) + + For complex-valued input: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.sinh(3-2j) + Array(-4.169-9.154j, dtype=complex64, weak_type=True) + >>> with jnp.printoptions(precision=3, suppress=True): + ... -1j * jnp.sin(1j * (3-2j)) + Array(-4.169-9.154j, dtype=complex64, weak_type=True) + """ return lax.sinh(*promote_args_inexact('sinh', x)) -@implements(np.cosh, module='numpy') + @partial(jit, inline=True) def cosh(x: ArrayLike, /) -> Array: + r"""Calculate element-wise hyperbolic cosine of input. + + JAX implementation of :obj:`numpy.cosh`. + + The hyperbolic cosine is defined by: + + .. math:: + + cosh(x) = \frac{e^x + e^{-x}}{2} + + Args: + x: input array or scalar. + + Returns: + An array containing the hyperbolic cosine of each element of ``x``, promoting + to inexact dtype. + + Note: + ``jnp.cosh`` is equivalent to computing ``jnp.cos(1j * x)``. + + See also: + - :func:`jax.numpy.sinh`: Computes the element-wise hyperbolic sine of the input. + - :func:`jax.numpy.tanh`: Computes the element-wise hyperbolic tangent of the + input. + - :func:`jax.numpy.arccosh`: Computes the element-wise inverse of hyperbolic + cosine of the input. + + Examples: + >>> x = jnp.array([[3, -1, 0], + ... [4, 7, -5]]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.cosh(x) + Array([[ 10.068, 1.543, 1. ], + [ 27.308, 548.317, 74.21 ]], dtype=float32) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.cos(1j * x) + Array([[ 10.068+0.j, 1.543+0.j, 1. +0.j], + [ 27.308+0.j, 548.317+0.j, 74.21 +0.j]], dtype=complex64, weak_type=True) + + For complex-valued input: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.cosh(5+1j) + Array(40.096+62.44j, dtype=complex64, weak_type=True) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.cos(1j * (5+1j)) + Array(40.096+62.44j, dtype=complex64, weak_type=True) + """ return lax.cosh(*promote_args_inexact('cosh', x)) @implements(np.arcsinh, module='numpy') @@ -735,9 +832,57 @@ def arccosh(x: ArrayLike, /) -> Array: result = _where(real(result) < 0, lax.neg(result), result) return result -@implements(np.tanh, module='numpy') + @partial(jit, inline=True) def tanh(x: ArrayLike, /) -> Array: + r"""Calculate element-wise hyperbolic tangent of input. + + JAX implementation of :obj:`numpy.tanh`. + + The hyperbolic tangent is defined by: + + .. math:: + + tanh(x) = \frac{sinh(x)}{cosh(x)} = \frac{e^x - e^{-x}}{e^x + e^{-x}} + + Args: + x: input array or scalar. + + Returns: + An array containing the hyperbolic tangent of each element of ``x``, promoting + to inexact dtype. + + Note: + ``jnp.tanh`` is equivalent to computing ``-1j * jnp.tan(1j * x)``. + + See also: + - :func:`jax.numpy.sinh`: Computes the element-wise hyperbolic sine of the input. + - :func:`jax.numpy.cosh`: Computes the element-wise hyperbolic cosine of the + input. + - :func:`jax.numpy.arctanh`: Computes the element-wise inverse of hyperbolic + tangent of the input. + + Examples: + >>> x = jnp.array([[-1, 0, 1], + ... [3, -2, 5]]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.tanh(x) + Array([[-0.762, 0. , 0.762], + [ 0.995, -0.964, 1. ]], dtype=float32) + >>> with jnp.printoptions(precision=3, suppress=True): + ... -1j * jnp.tan(1j * x) + Array([[-0.762+0.j, 0. -0.j, 0.762-0.j], + [ 0.995-0.j, -0.964+0.j, 1. -0.j]], dtype=complex64, weak_type=True) + + For complex-valued input: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.tanh(2-5j) + Array(1.031+0.021j, dtype=complex64, weak_type=True) + >>> with jnp.printoptions(precision=3, suppress=True): + ... -1j * jnp.tan(1j * (2-5j)) + Array(1.031+0.021j, dtype=complex64, weak_type=True) + """ return lax.tanh(*promote_args_inexact('tanh', x)) @implements(np.arctanh, module='numpy')