Implement erfinv for Float32 and Float64 #1344

Open: wants to merge 1 commit into base: main
4 changes: 4 additions & 0 deletions lib/complex.dx
@@ -100,6 +100,9 @@ def complex_erf(x:Complex) -> Complex =
def complex_erfc(x:Complex) -> Complex =
  todo

def complex_erfinv(x:Complex) -> Complex =
  todo

def complex_log1p(x:Complex) -> Complex =
  case x.re == 0.0 of
    True -> x
@@ -130,3 +133,4 @@ instance Floating(Complex)
  def lgamma(x) = complex_lgamma(x)
  def erf(x) = complex_erf(x)
  def erfc(x) = complex_erfc(x)
  def erfinv(x) = complex_erfinv(x)
125 changes: 125 additions & 0 deletions lib/prelude.dx
@@ -1050,6 +1050,7 @@ interface Floating(a:Type)
  lgamma : (a) -> a
  erf : (a) -> a
  erfc : (a) -> a
  erfinv : (a) -> a

def lbeta(x:a, y:a) -> a given (a|Sub|Floating) = lgamma x + lgamma y - lgamma (x + y)

@@ -1066,6 +1067,127 @@ def float64_cosh(x:Float64) -> Float64 = %fdiv(%fadd(%exp(x), %exp(%fsub(f_to_f6
def float64_tanh(x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x)))
                                              ,%fadd(%exp(x), %exp(%fsub(f_to_f64 0.0, x))))

# Polynomial evaluation by Horner's method
def unsafe_horner(x:a, ys:n=>a) -> a given (a|Add|Mul, n|Ix) =
Contributor

This duplicates the evalpoly function in the same file, though yours is maybe a little more general since it doesn't require a zero.

https://github.com/google-research/dex-lang/blob/main/lib/prelude.dx#L2225

Author

Ah right, thanks for pointing this out, I should have left a comment about it somewhere. The duplication is unfortunate but was intentional, because:

  • Neither evalpoly nor any of the reduction machinery is defined yet at this point in the prelude, and writing out a polynomial evaluation by Horner's method by hand, with many coefficients in multiple branches, is quite tedious. (I believe one of the other special functions does this.) I was unsure about the ramifications of shuffling the contents of the prelude around to make this possible, so I left things in place.
  • evalpoly requires VSpace and Float, but we want to support both Float (AKA Float32) and Float64, and the latter doesn't implement VSpace (see also Can't negate a Float64 #1345).

I tried to differentiate it from evalpoly as more of an internal helper than an intended user-facing function in part by prefixing it with unsafe_, which seemed appropriate since it directly uses two functions similarly marked as unsafe.

> though yours is maybe a little more general since it doesn't require a zero.

As an aside, IIUC, I think mine does still require a zero since that's part of the Add interface and I have a|Add.
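
For readers less familiar with the recurrence under discussion, here is a minimal Python sketch of Horner's method. It is not the Dex code; the name `horner` is hypothetical. It follows the same convention as unsafe_horner: coefficients are given in ascending degree order and the accumulator is seeded with the highest-degree coefficient. Because of that seeding, the loop itself never touches an additive identity, even though a constraint like a|Add still brings a zero into scope.

```python
from fractions import Fraction

def horner(x, coeffs):
    """Evaluate coeffs[0] + coeffs[1]*x + coeffs[2]*x**2 + ... by Horner's method."""
    if not coeffs:
        raise ValueError("need at least one coefficient")
    acc = coeffs[-1]                   # seed with the highest-degree coefficient
    for c in reversed(coeffs[:-1]):    # walk down to the constant term
        acc = c + x * acc
    return acc

# Works for any type that supports + and *, for example floats or exact fractions.
print(horner(2.0, [1.0, 2.0, 3.0]))    # 1 + 2*2 + 3*4 = 17.0
print(horner(Fraction(1, 2), [Fraction(1), Fraction(2), Fraction(3)]))  # 11/4
```

The only requirement on the coefficient list is that it is non-empty, which is the same precondition that makes the Dex helper "unsafe".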

Contributor

Oh, thanks for explaining. Yes, moving things within the prelude is a punishing game of dependency tetris. @dougalm has usually been fine with me moving things around as long as the tests still pass, but I would be surprised if you could move evalpoly up 1000 lines and have it still work.

Good point about the zero, I got mixed up about which method is in which class.

Author

Actually, it might take a bit of additional shuffling and/or finagling, but I think the Floating interface and its instances could be moved down below evalpoly. That way they'd also be able to use copysign, inf, and nan (at least for the Float32 implementation, anyway). I believe the only use of Floating prior to the definition of evalpoly is for std, which would need to be moved as well. Though even if this shuffling is done, evalpoly itself would need to be amended in order to be usable for the Float64 definition.

  n' = unsafe_i_to_n(n_to_i(size n) - 1)
  yield_state ys[unsafe_from_ordinal n'] \ref. rof i:(Fin n').
    ref := ys[unsafe_from_ordinal (ordinal i)] + x * get ref

# `erfinv` implementations for `Float32` and `Float64` are based on those in Julia in
# https://github.com/JuliaMath/SpecialFunctions.jl, which uses the following reference:
# Blair, J. M., Edwards, C. A., & Johnson, J. H. (1976). Rational Chebyshev approximations
# for the inverse of the error function. In Mathematics of Computation (Vol. 30, Issue 136,
# pp. 827–830). American Mathematical Society (AMS).
# https://doi.org/10.1090/s0025-5718-1976-0421040-7
def float32_erfinv(x:Float32) -> Float32 =
  a = select(x > 0.0, x, -x)
  if a >= 1.0
    then
      inf = f_to_f32(1.0 / 0.0)
      if x == 1.0
        then inf
        else
          if x == -1.0
            then -inf
            else f_to_f32(0.0 / 0.0) # TODO: this should probably error but `error` is not defined yet
    else
      if a <= 0.75 # Blair table 10
        then
          t = x * x - 0.5625
          p1 = unsafe_horner t [-0.130959967422e+2, 0.26785225760e+2, -0.9289057365e+1]
          p2 = unsafe_horner t [-0.120749426297e+2, 0.30960614529e+2, -0.17149977991e+2, 0.1e+1]
          f_to_f32(x * (p1 / p2))
        else
          if a <= 0.9375 # Blair table 29
            then
              t = x * x - 0.87890625
              p1 = unsafe_horner t [-0.12402565221, 0.10688059574e+1, -0.19594556078e+1, 0.4230581357]
              p2 = unsafe_horner t [-0.8827697997e-1, 0.8900743359, -0.21757031196e+1, 0.1e+1]
              f_to_f32(x * (p1 / p2))
            else # Blair table 50
              t = 1.0 / %sqrt(-%log1p(-a))
              p1 = unsafe_horner t [0.1550470003116, 0.1382719649631e+1, 0.690969348887, -0.1128081391617e+1, 0.680544246825, -0.16444156791]
              p2 = unsafe_horner t [0.155024849822, 0.1385228141995e+1, 0.1e+1]
              s = select(x > 0.0, t, select(x < 0.0, (-t), 0.0))
              f_to_f32(p1 / (s * p2))

def float64_erfinv(x:Float64) -> Float64 =
  zero64 = (zero::Float64)
  one64 = (one::Float64)
  a = select(x > zero64, x, %fsub(zero64, x))
  if a >= one64
    then
      inf = %fdiv(one64, zero64)
      if x == one64
        then inf
        else
          if x == f_to_f64(-1.0)
            then %fsub(zero64, inf)
            else %fdiv(zero64, zero64)
    else
      if a <= f_to_f64(0.75) # Blair table 17
        then
          t = %fsub(%fmul(x, x), f_to_f64(0.5625))
          p1 = unsafe_horner t [f_to_f64( 0.160304955844066229311e2),
                                f_to_f64(-0.90784959262960326650e2),
                                f_to_f64( 0.18644914861620987391e3),
                                f_to_f64(-0.16900142734642382420e3),
                                f_to_f64( 0.6545466284794487048e2),
                                f_to_f64(-0.864213011587247794e1),
                                f_to_f64( 0.1760587821390590)]
          p2 = unsafe_horner t [f_to_f64( 0.147806470715138316110e2),
                                f_to_f64(-0.91374167024260313936e2),
                                f_to_f64( 0.21015790486205317714e3),
                                f_to_f64(-0.22210254121855132366e3),
                                f_to_f64( 0.10760453916055123830e3),
                                f_to_f64(-0.206010730328265443e2),
                                f_to_f64( 0.1e1)]
          %fmul(x, %fdiv(p1, p2))
        else
          if a <= f_to_f64(0.9375) # Blair table 37
            then
              t = %fsub(%fmul(x, x), f_to_f64(0.87890625))
              p1 = unsafe_horner t [f_to_f64(-0.152389263440726128e-1),
                                    f_to_f64( 0.3444556924136125216),
                                    f_to_f64(-0.29344398672542478687e1),
                                    f_to_f64( 0.11763505705217827302e2),
                                    f_to_f64(-0.22655292823101104193e2),
                                    f_to_f64( 0.19121334396580330163e2),
                                    f_to_f64(-0.5478927619598318769e1),
                                    f_to_f64( 0.237516689024448)]
              p2 = unsafe_horner t [f_to_f64(-0.108465169602059954e-1),
                                    f_to_f64( 0.2610628885843078511),
                                    f_to_f64(-0.24068318104393757995e1),
                                    f_to_f64( 0.10695129973387014469e2),
                                    f_to_f64(-0.23716715521596581025e2),
                                    f_to_f64( 0.24640158943917284883e2),
                                    f_to_f64(-0.10014376349783070835e2),
                                    f_to_f64( 0.1e1)]
              %fmul(x, %fdiv(p1, p2))
            else # Blair table 57
              t = %fdiv(one64, %sqrt(%fsub(zero64, %log1p(%fsub(zero64, a)))))
              p1 = unsafe_horner t [f_to_f64(0.10501311523733438116e-3),
                                    f_to_f64(0.1053261131423333816425e-1),
                                    f_to_f64(0.26987802736243283544516),
                                    f_to_f64(0.23268695788919690806414e1),
                                    f_to_f64(0.71678547949107996810001e1),
                                    f_to_f64(0.85475611822167827825185e1),
                                    f_to_f64(0.68738088073543839802913e1),
                                    f_to_f64(0.3627002483095870893002e1),
                                    f_to_f64(0.886062739296515468149)]
              p2 = unsafe_horner t [f_to_f64(0.10501266687030337690e-3),
                                    f_to_f64(0.1053286230093332753111e-1),
                                    f_to_f64(0.27019862373751554845553),
                                    f_to_f64(0.23501436397970253259123e1),
                                    f_to_f64(0.76078028785801277064351e1),
                                    f_to_f64(0.111815861040569078273451e2),
                                    f_to_f64(0.119487879184353966678438e2),
                                    f_to_f64(0.81922409747269907893913e1),
                                    f_to_f64(0.4099387907636801536145e1),
                                    f_to_f64(0.1e1)]
              s = select(x > zero64, t, select(x < zero64, %fsub(zero64, t), zero64))
              %fdiv(p1, %fmul(s, p2))

instance Floating(Float64)
  def exp(x) = %exp(x)
  def exp2(x) = %exp2(x)
@@ -1087,6 +1209,7 @@ instance Floating(Float64)
  def lgamma(x)= %lgamma(x)
  def erf(x) = %erf(x)
  def erfc(x) = %erfc(x)
  def erfinv(x)= float64_erfinv(x)

instance Floating(Float32)
  def exp(x) = %exp(x)
@@ -1109,6 +1232,7 @@ instance Floating(Float32)
  def lgamma(x)= %lgamma(x)
  def erf(x) = %erf(x)
  def erfc(x) = %erfc(x)
  def erfinv(x)= float32_erfinv(x)

'## Raw pointer operations

@@ -1249,6 +1373,7 @@ instance Floating(n=>a) given (a|Floating, n|Ix)
  def lgamma(x) = each x lgamma
  def erf(x) = each x erf
  def erfc(x) = each x erfc
  def erfinv(x) = each x erfinv

'### Reductions

15 changes: 15 additions & 0 deletions tests/eval-tests.dx
@@ -123,6 +123,21 @@ fun = \y. sum (map n_to_f arr) + y
:p f_to_i $ round 3.6
> 4

:p erfinv(f_to_f64 0.84270079294971486934)
> 1.

:p erfinv 1.0
> inf

# TODO: This should actually be an error since it's outside of the domain of the function
:p erfinv 2.0
> nan

:p
  xs = each [-0.99, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 0.99] f_to_f64
  erf(erfinv xs) ~~ erfinv(erf xs) && erfinv(xs) ~~ each xs \x. zero - erfinv(zero - x)
> True

s = 1.0

:p s
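
The round-trip test in this hunk converts its inputs with f_to_f64, so the polynomial branches of the Float32 implementation are not exercised by it. As a rough cross-check outside Dex, the Python sketch below evaluates the Float32 tail branch (Blair table 50, 0.9375 < |x| < 1) and compares it with a reference built from the standard library's normal quantile. The names `horner`, `erfinv_tail`, and `erfinv_ref` are hypothetical, the coefficients are assumed to match those listed for the Float32 method in SpecialFunctions.jl, and the arithmetic runs in double precision, so this only approximates what the Dex code computes.

```python
import math
from statistics import NormalDist

def horner(x, coeffs):
    """Evaluate a polynomial with coefficients in ascending degree order."""
    acc = coeffs[-1]
    for c in reversed(coeffs[:-1]):
        acc = c + x * acc
    return acc

# Assumption: table-50 coefficients as listed in SpecialFunctions.jl's Float32 erfinv.
P50 = [0.1550470003116, 1.382719649631, 0.690969348887,
       -1.128081391617, 0.680544246825, -0.16444156791]
Q50 = [0.155024849822, 1.385228141995, 1.0]

def erfinv_tail(x):
    """Approximate erfinv(x) for 0.9375 < |x| < 1."""
    t = 1.0 / math.sqrt(-math.log1p(-abs(x)))
    return horner(t, P50) / (math.copysign(t, x) * horner(t, Q50))

def erfinv_ref(x):
    """Reference via the normal quantile: erfinv(x) = Phi^-1((x + 1) / 2) / sqrt(2)."""
    return NormalDist().inv_cdf((x + 1.0) / 2.0) / math.sqrt(2.0)

for x in (0.95, 0.99, -0.99):
    approx = erfinv_tail(x)
    print(f"x={x:+.2f} approx={approx:.7f} ref={erfinv_ref(x):.7f} "
          f"roundtrip_err={math.erf(approx) - x:+.1e}")
```

A similar check for the two central branches would need their own coefficient tables; the same `horner` helper applies with t = x*x - 0.5625 or t = x*x - 0.87890625.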