Skip to content

Commit

Permalink
Merge pull request #23949 from rajasekharporeddy:testbranch1
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 679264758
  • Loading branch information
Google-ML-Automation committed Sep 26, 2024
2 parents 6f7ad64 + 7e6fa3e commit 46dbb65
Showing 1 changed file with 148 additions and 3 deletions.
151 changes: 148 additions & 3 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,14 +710,111 @@ def arctan(x: ArrayLike, /) -> Array:
"""
return lax.atan(*promote_args_inexact('arctan', x))

@implements(np.sinh, module='numpy')

@partial(jit, inline=True)
def sinh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise hyperbolic sine of input.
JAX implementation of :obj:`numpy.sinh`.
The hyperbolic sine is defined by:
.. math::
sinh(x) = \frac{e^x - e^{-x}}{2}
Args:
x: input array or scalar.
Returns:
An array containing the hyperbolic sine of each element of ``x``, promoting
to inexact dtype.
Note:
``jnp.sinh`` is equivalent to computing ``-1j * jnp.sin(1j * x)``.
See also:
- :func:`jax.numpy.cosh`: Computes the element-wise hyperbolic cosine of the
input.
- :func:`jax.numpy.tanh`: Computes the element-wise hyperbolic tangent of the
input.
- :func:`jax.numpy.arcsinh`: Computes the element-wise inverse of hyperbolic
sine of the input.
Examples:
>>> x = jnp.array([[-2, 3, 5],
... [0, -1, 4]])
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.sinh(x)
Array([[-3.627, 10.018, 74.203],
[ 0. , -1.175, 27.29 ]], dtype=float32)
>>> with jnp.printoptions(precision=3, suppress=True):
... -1j * jnp.sin(1j * x)
Array([[-3.627+0.j, 10.018-0.j, 74.203-0.j],
[ 0. -0.j, -1.175+0.j, 27.29 -0.j]], dtype=complex64, weak_type=True)
For complex-valued input:
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.sinh(3-2j)
Array(-4.169-9.154j, dtype=complex64, weak_type=True)
>>> with jnp.printoptions(precision=3, suppress=True):
... -1j * jnp.sin(1j * (3-2j))
Array(-4.169-9.154j, dtype=complex64, weak_type=True)
"""
return lax.sinh(*promote_args_inexact('sinh', x))

@implements(np.cosh, module='numpy')

@partial(jit, inline=True)
def cosh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise hyperbolic cosine of input.
JAX implementation of :obj:`numpy.cosh`.
The hyperbolic cosine is defined by:
.. math::
cosh(x) = \frac{e^x + e^{-x}}{2}
Args:
x: input array or scalar.
Returns:
An array containing the hyperbolic cosine of each element of ``x``, promoting
to inexact dtype.
Note:
``jnp.cosh`` is equivalent to computing ``jnp.cos(1j * x)``.
See also:
- :func:`jax.numpy.sinh`: Computes the element-wise hyperbolic sine of the input.
- :func:`jax.numpy.tanh`: Computes the element-wise hyperbolic tangent of the
input.
- :func:`jax.numpy.arccosh`: Computes the element-wise inverse of hyperbolic
cosine of the input.
Examples:
>>> x = jnp.array([[3, -1, 0],
... [4, 7, -5]])
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.cosh(x)
Array([[ 10.068, 1.543, 1. ],
[ 27.308, 548.317, 74.21 ]], dtype=float32)
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.cos(1j * x)
Array([[ 10.068+0.j, 1.543+0.j, 1. +0.j],
[ 27.308+0.j, 548.317+0.j, 74.21 +0.j]], dtype=complex64, weak_type=True)
For complex-valued input:
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.cosh(5+1j)
Array(40.096+62.44j, dtype=complex64, weak_type=True)
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.cos(1j * (5+1j))
Array(40.096+62.44j, dtype=complex64, weak_type=True)
"""
return lax.cosh(*promote_args_inexact('cosh', x))

@implements(np.arcsinh, module='numpy')
Expand All @@ -735,9 +832,57 @@ def arccosh(x: ArrayLike, /) -> Array:
result = _where(real(result) < 0, lax.neg(result), result)
return result

@implements(np.tanh, module='numpy')

@partial(jit, inline=True)
def tanh(x: ArrayLike, /) -> Array:
r"""Calculate element-wise hyperbolic tangent of input.
JAX implementation of :obj:`numpy.tanh`.
The hyperbolic tangent is defined by:
.. math::
tanh(x) = \frac{sinh(x)}{cosh(x)} = \frac{e^x - e^{-x}}{e^x + e^{-x}}
Args:
x: input array or scalar.
Returns:
An array containing the hyperbolic tangent of each element of ``x``, promoting
to inexact dtype.
Note:
``jnp.tanh`` is equivalent to computing ``-1j * jnp.tan(1j * x)``.
See also:
- :func:`jax.numpy.sinh`: Computes the element-wise hyperbolic sine of the input.
- :func:`jax.numpy.cosh`: Computes the element-wise hyperbolic cosine of the
input.
- :func:`jax.numpy.arctanh`: Computes the element-wise inverse of hyperbolic
tangent of the input.
Examples:
>>> x = jnp.array([[-1, 0, 1],
... [3, -2, 5]])
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.tanh(x)
Array([[-0.762, 0. , 0.762],
[ 0.995, -0.964, 1. ]], dtype=float32)
>>> with jnp.printoptions(precision=3, suppress=True):
... -1j * jnp.tan(1j * x)
Array([[-0.762+0.j, 0. -0.j, 0.762-0.j],
[ 0.995-0.j, -0.964+0.j, 1. -0.j]], dtype=complex64, weak_type=True)
For complex-valued input:
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.tanh(2-5j)
Array(1.031+0.021j, dtype=complex64, weak_type=True)
>>> with jnp.printoptions(precision=3, suppress=True):
... -1j * jnp.tan(1j * (2-5j))
Array(1.031+0.021j, dtype=complex64, weak_type=True)
"""
return lax.tanh(*promote_args_inexact('tanh', x))

@implements(np.arctanh, module='numpy')
Expand Down

0 comments on commit 46dbb65

Please sign in to comment.