Skip to content

Commit

Permalink
feat(lib): add real-valued implementation of jax.scipy.special.fresnel
Browse files Browse the repository at this point in the history
Add implementation, documentation, and tests, for both single-precision and double-precision floating-point arithmetic.
  • Loading branch information
jeertmans committed Sep 3, 2024
1 parent 5ce66dc commit f9cb95c
Show file tree
Hide file tree
Showing 4 changed files with 331 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/jax.scipy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ jax.scipy.special
expit
expn
factorial
fresnel
gamma
gammainc
gammaincc
Expand Down
322 changes: 322 additions & 0 deletions jax/_src/third_party/scipy/special.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
from __future__ import annotations

import jax.numpy as jnp
from jax import jit

from jax._src import custom_derivatives, dtypes
from jax._src.numpy.lax_numpy import complexfloating
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike


@jit
def sincospisquaredhalf(
x: Array,
) -> tuple[Array, Array]:
"""
Accurate evaluation of sin(pi * x**2 / 2) and cos(pi * x**2 / 2).
As based on the sinpi and cospi functions from SciPy, see:
- https://github.com/scipy/scipy/blob/v1.14.0/scipy/special/special/cephes/trig.h
"""
x = jnp.abs(x)
# define s = x % 2, y = x - s, then
# r = (x * x / 2) % 2
# = [(y + s)*(y + s)/2] % 2
# = [y*y/2 + s*y + s*s/2] % 2
# = [(y*y/2)%2 + (s*y + s*s/2)%2]%2
# = [0 + (s*(y+s/2))%2]%2
# = [s*(x-s/2)]%2
s = jnp.fmod(x, 2.0)
r = jnp.fmod(s * (x - s / 2), 2.0)

sinpi = jnp.where(
r < 0.5,
jnp.sin(jnp.pi * r),
jnp.where(
r > 1.5,
jnp.sin(jnp.pi * (r - 2.0)),
-jnp.sin(jnp.pi * (r - 1.0)),
),
)
cospi = jnp.where(
r == 0.5,
0.0,
jnp.where(r < 1.0, -jnp.sin(jnp.pi * (r - 0.5)), jnp.sin(jnp.pi * (r - 1.5))),
)

return sinpi, cospi


@custom_derivatives.custom_jvp
def fresnel(x: ArrayLike) -> tuple[Array, Array]:
r"""The Fresnel integrals
JAX implementation of :obj:`scipy.special.fresnel`.
The Fresnel integrals are defined as
.. math::
S(x) &= \int_0^x \sin(\pi t^2 /2) dt \\
C(x) &= \int_0^x \cos(\pi t^2 /2) dt.
Args:
x: arraylike, real-valued.
Returns:
Arrays containing the values of the Fresnel integrals.
Notes:
The JAX version only supports real-valued inputs, and
is based on the SciPy C++ implementation, see
`here
<https://github.com/scipy/scipy/blob/v1.14.0/scipy/special/special/cephes/fresnl.h>`_.
For ``float32`` dtypes, the implementation is directly based
on the Cephes implementation ``fresnlf``.
As for the original Cephes implementation, the accuracy
is only guaranteed in the domain [-10, 10]. Outside of
that domain, one could observe divergence between the
theoretical derivatives and the custom JVP implementation,
especially for large input values.
Finally, for half-precision data types, ``float16``
and ``bfloat16``, the array elements are upcasted to
``float32`` as the Cephes coefficients used in
series expansions would otherwise lead to poor results.
Other data types, like ``float8``, are not supported.
"""

xxa, = promote_args_inexact("fresnel", x)
original_dtype = xxa.dtype

# This part is mostly a direct translation of SciPy's C++ code,
# and the original Cephes implementation for single precision.

if dtypes.issubdtype(xxa.dtype, complexfloating):
raise NotImplementedError(
'Support for complex-valued inputs is not implemented yet.')
elif xxa.dtype in (jnp.float32, jnp.float16, jnp.bfloat16):
# Single-precision Cephes coefficients

# For half-precision, series expansions have either
# produce overflow or poor accuracy.
# Upcasting to single-precision is hence needed.
xxa = xxa.astype(jnp.float32) # No-op for float32

fresnl_sn = jnp.array([
+1.647629463788700e-9,
-1.522754752581096e-7,
+8.424748808502400e-6,
-3.120693124703272e-4,
+7.244727626597022e-3,
-9.228055941124598e-2,
+5.235987735681432e-1,
], dtype=jnp.float32)

fresnl_cn = jnp.array([
+1.416802502367354e-8,
-1.157231412229871e-6,
+5.387223446683264e-5,
-1.604381798862293e-3,
+2.818489036795073e-2,
-2.467398198317899e-1,
+9.999999760004487e-1,
], dtype=jnp.float32)

fresnl_fn = jnp.array([
-1.903009855649792e12,
+1.355942388050252e11,
-4.158143148511033e9,
+7.343848463587323e7,
-8.732356681548485e5,
+8.560515466275470e3,
-1.032877601091159e2,
+2.999401847870011e0,
], dtype=jnp.float32)

fresnl_gn = jnp.array([
-1.860843997624650e11,
+1.278350673393208e10,
-3.779387713202229e8,
+6.492611570598858e6,
-7.787789623358162e4,
+8.602931494734327e2,
-1.493439396592284e1,
+9.999841934744914e-1,
], dtype=jnp.float32)
elif xxa.dtype == jnp.float64:
# Double-precision Cephes coefficients

fresnl_sn = jnp.array([
-2.99181919401019853726e3,
+7.08840045257738576863e5,
-6.29741486205862506537e7,
+2.54890880573376359104e9,
-4.42979518059697779103e10,
+3.18016297876567817986e11,
], dtype=jnp.float64)

fresnl_sd = jnp.array([
+1.00000000000000000000e0,
+2.81376268889994315696e2,
+4.55847810806532581675e4,
+5.17343888770096400730e6,
+4.19320245898111231129e8,
+2.24411795645340920940e10,
+6.07366389490084639049e11,
], dtype=jnp.float64)

fresnl_cn = jnp.array([
-4.98843114573573548651e-8,
+9.50428062829859605134e-6,
-6.45191435683965050962e-4,
+1.88843319396703850064e-2,
-2.05525900955013891793e-1,
+9.99999999999999998822e-1,
], dtype=jnp.float64)

fresnl_cd = jnp.array([
+3.99982968972495980367e-12,
+9.15439215774657478799e-10,
+1.25001862479598821474e-7,
+1.22262789024179030997e-5,
+8.68029542941784300606e-4,
+4.12142090722199792936e-2,
+1.00000000000000000118e0,
], dtype=jnp.float64)

fresnl_fn = jnp.array([
+4.21543555043677546506e-1,
+1.43407919780758885261e-1,
+1.15220955073585758835e-2,
+3.45017939782574027900e-4,
+4.63613749287867322088e-6,
+3.05568983790257605827e-8,
+1.02304514164907233465e-10,
+1.72010743268161828879e-13,
+1.34283276233062758925e-16,
+3.76329711269987889006e-20,
], dtype=jnp.float64)

fresnl_fd = jnp.array([
+1.00000000000000000000e0,
+7.51586398353378947175e-1,
+1.16888925859191382142e-1,
+6.44051526508858611005e-3,
+1.55934409164153020873e-4,
+1.84627567348930545870e-6,
+1.12699224763999035261e-8,
+3.60140029589371370404e-11,
+5.88754533621578410010e-14,
+4.52001434074129701496e-17,
+1.25443237090011264384e-20,
], dtype=jnp.float64)

fresnl_gn = jnp.array([
+5.04442073643383265887e-1,
+1.97102833525523411709e-1,
+1.87648584092575249293e-2,
+6.84079380915393090172e-4,
+1.15138826111884280931e-5,
+9.82852443688422223854e-8,
+4.45344415861750144738e-10,
+1.08268041139020870318e-12,
+1.37555460633261799868e-15,
+8.36354435630677421531e-19,
+1.86958710162783235106e-22,
], dtype=jnp.float64)

fresnl_gd = jnp.array([
+1.00000000000000000000e0,
+1.47495759925128324529e0,
+3.37748989120019970451e-1,
+2.53603741420338795122e-2,
+8.14679107184306179049e-4,
+1.27545075667729118702e-5,
+1.04314589657571990585e-7,
+4.60680728146520428211e-10,
+1.10273215066240270757e-12,
+1.38796531259578871258e-15,
+8.39158816283118707363e-19,
+1.86958710162783236342e-22,
], dtype=jnp.float64)
else:
raise NotImplementedError(
f'Support for {xxa.dtype} dtype is not implemented yet.')

assert xxa.dtype in (jnp.float32, jnp.float64)
single_precision = (xxa.dtype == jnp.float32)

x = jnp.abs(xxa)

x2 = x * x

# Infinite x values
s_inf = c_inf = 0.5

# Small x values
t = x2 * x2

if single_precision:
s_small = x * x2 * jnp.polyval(fresnl_sn, t)
c_small = x * jnp.polyval(fresnl_cn, t)
else:
s_small = x * x2 * jnp.polyval(fresnl_sn[:6], t) / jnp.polyval(fresnl_sd[:7], t)
c_small = x * jnp.polyval(fresnl_cn[:6], t) / jnp.polyval(fresnl_cd[:7], t)

# Large x values

sinpi, cospi = sincospisquaredhalf(x)

if single_precision:
c_large = c_inf
s_large = s_inf
else:
c_large = 0.5 + 1 / (jnp.pi * x) * sinpi
s_large = 0.5 - 1 / (jnp.pi * x) * cospi

# Other x values
t = jnp.pi * x2
u = 1.0 / (t * t)
t = 1.0 / t

if single_precision:
f = 1.0 - u * jnp.polyval(fresnl_fn, u)
g = t * jnp.polyval(fresnl_gn, u)
else:
f = 1.0 - u * jnp.polyval(fresnl_fn, u) / jnp.polyval(fresnl_fd, u)
g = t * jnp.polyval(fresnl_gn, u) / jnp.polyval(fresnl_gd, u)

t = jnp.pi * x
c_other = 0.5 + (f * sinpi - g * cospi) / t
s_other = 0.5 - (f * cospi + g * sinpi) / t

isinf = jnp.isinf(xxa)
small = x2 < 2.5625
large = x > 36974.0
s = jnp.where(
isinf, s_inf, jnp.where(small, s_small, jnp.where(large, s_large, s_other))
)
c = jnp.where(
isinf, c_inf, jnp.where(small, c_small, jnp.where(large, c_large, c_other))
)

neg = xxa < 0.0
s = jnp.where(neg, -s, s)
c = jnp.where(neg, -c, c)

if original_dtype != xxa.dtype:
s = s.astype(original_dtype)
c = c.astype(original_dtype)

return s, c

def _fresnel_jvp(primals, tangents):
x, = primals
x_dot, = tangents
result = fresnel(x)
sinpi, cospi = sincospisquaredhalf(x)
dSdx = sinpi * x_dot
dCdx = cospi * x_dot
return result, (dSdx, dCdx)
fresnel.defjvp(_fresnel_jvp)
4 changes: 4 additions & 0 deletions jax/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,7 @@
xlogy as xlogy,
zeta as zeta,
)

from jax._src.third_party.scipy.special import (
fresnel as fresnel,
)
4 changes: 4 additions & 0 deletions tests/lax_scipy_special_functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
op_record(
"factorial", 1, float_dtypes, jtu.rand_default, True
),
op_record(
"fresnel", 1, float_dtypes,
functools.partial(jtu.rand_default, scale=30), True
),
op_record(
"i0", 1, float_dtypes, jtu.rand_default, True
),
Expand Down

0 comments on commit f9cb95c

Please sign in to comment.