-
Notifications
You must be signed in to change notification settings - Fork 1
/
rootfind.sml
182 lines (156 loc) · 5.73 KB
/
rootfind.sml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
structure RootFind =
struct
  (* Numeric root finding.
     Based on OCaml code by Edgar Friendly <thelema314@gmail.com>
     Use under LGPL2.1 license + OCaml linking exception.

     Does not assume differentiability of the function whose roots are
     being found.  Every entry point takes an initial bracket a, b with
     f a * f b < 0.0 and raises Fail otherwise.

     Implements the following algorithms:
     * Bisection (c = (a+b)/2)
     * Secant method kept inside the bracket (linear interpolation from a
       and b, Illinois variant of regula falsi so that one endpoint cannot
       stay fixed forever)
     * Brent's method (bisection + secant + inverse quadratic interpolation)
  *)

  (* Switches: [debug] prints the state of every Brent step, [trace]
     prints each new Brent iterate p_i together with f(p_i). *)
  val debug = false
  val trace = false

  (* Printf format combinators; below, B formats the bool conditions,
     I the int iteration counter and R reals. *)
  val B = Printf.B
  val I = Printf.I
  val R = Printf.R
  val ` = Printf.`
  val $ = Printf.$

  (* true iff x lies strictly between a and b, in either order.  Used to
     detect an interpolated point that landed on (or outside) an endpoint,
     and a bracket that has shrunk to two adjacent floats, where no point
     strictly inside exists and the bracket can no longer shrink. *)
  fun strictly_between (a, b) x =
      Real.min (a, b) < x andalso x < Real.max (a, b)

  (* One step of Brent's method.
     a, fa, b, fb : current bracket (fa and fb have opposite signs); b is
                    the endpoint with the smaller |f| because this is only
                    entered through brent_int_swap.
     c, fc        : the previous value of b.
     d            : the value of c before that; only read when the previous
                    step was not a bisection.
     mflag        : true iff the previous step was a bisection.
     i            : iteration counter, only used for tracing.
     Returns the candidate s once |f s| < ydelta; the interval-width stop
     is handled by brent_int_swap. *)
  fun brent_int xdelta ydelta f a fa b fb c fc mflag d i =
      (if debug
       then Printf.printf `"a:"R `" fa:"R `" b:"R `" fb:"R `" c:"R `" fc:"R `"\n" $ a fa b fb c fc
       else ();
       let open Real
           (* Candidate s: inverse quadratic interpolation through (a,fa),
              (b,fb), (c,fc) when fc differs from both fa and fb, otherwise
              the secant through (a,fa) and (b,fb). *)
           val s = if Real.!= (fb, fc) andalso Real.!= (fa, fc)
                   then (if debug then Printf.printf `"IQI" $ else ();
                         a * fb * fc / (fa - fb) / (fa - fc) +
                         b * fa * fc / (fb - fa) / (fb - fc) +
                         c * fa * fb / (fc - fa) / (fc - fb))
                   else (if debug then Printf.printf `" S " $ else ();
                         b - fb * (b - a) / (fb - fa))
           val _ = if debug then Printf.printf `"s0:"R`"\n" $ s else ()
           (* Conditions 1-5 reject the interpolated s in favour of a
              bisection step.  Condition 1 is "s is not between (3a+b)/4
              and b", which needs orelse: with andalso it can never hold. *)
           val q = (3.0 * a + b) * 0.25
           val c1 = if a < b
                    then s < q orelse s > b
                    else s > q orelse s < b
           val c2 = mflag andalso Real.abs (s - b) >= Real.abs (b - c) / 2.0
           val c3 = not mflag andalso Real.abs (s - b) >= Real.abs (c - d) / 2.0
           val c4 = mflag andalso Real.abs (b - c) < xdelta
           val c5 = not mflag andalso Real.abs (c - d) < xdelta
           val _ = if debug then Printf.printf `"c1: "B `"c2: "B `"c3: "B `"c4: "B `"c5: "B `"\n" $ c1 c2 c3 c4 c5 else ()
           (* TODO: don't compute all conditions *)
           val (s, mflag) =
               if c1 orelse c2 orelse c3 orelse c4 orelse c5
               then ((a + b) / 2.0, true)
               else (s, false)
           val fs = f s
           val _ = if trace then Printf.printf `"p_"I `":"R `" f(p_"I `")="R `"\n" $ i s i fs else ()
           val next = Int.+ (i, 1)
       in
         if Real.abs fs < ydelta then s
         else if fa * fs < 0.0
         then (* sign change on [a, s]: s becomes b, old b becomes c *)
           brent_int_swap xdelta ydelta f a fa s fs b fb mflag c next
         else (* sign change on [s, b]: s replaces a, old b becomes c *)
           brent_int_swap xdelta ydelta f s fs b fb b fb mflag c next
       end)

  (* Helper between Brent steps: checks the stopping rule on the bracket
     width, then swaps a and b when needed.  Returns b once the bracket is
     narrower than xdelta, or once a and b are adjacent floats (then the
     bracket cannot shrink any further, even if xdelta is below the float
     spacing).  Otherwise swaps a and b if |fa| < |fb| so that b holds the
     best estimate so far, and takes another step. *)
  and brent_int_swap xdelta ydelta f a fa b fb c fc mflag d i =
      if Real.abs (b - a) < xdelta
         orelse not (strictly_between (a, b) ((a + b) * 0.5))
      then b
      else if Real.abs fa < Real.abs fb
      then brent_int xdelta ydelta f b fb a fa c fc mflag d i
      else brent_int xdelta ydelta f a fa b fb c fc mflag d i

  (* Raise Fail, tagged with the caller name [loc], reporting that f a
     and f b do not have opposite signs. *)
  fun error_bracket (loc, a, fa, b, fb) =
      raise Fail (loc ^ ": root must be bracketed:" ^
                  " f(" ^ Real.toString a ^ ") = " ^ Real.toString fa ^
                  " f(" ^ Real.toString b ^ ") = " ^ Real.toString fb)

  (* Brent's method on the bracket a, b.  delta is used both as the
     bracket-width tolerance and as the |f| tolerance.
     Raises Fail unless f a * f b < 0.0. *)
  fun brent delta f a b =
      let val fa = f a
          val fb = f b
      in
        if fa * fb >= 0.0
        then error_bracket ("RootFind.brent", a, fa, b, fb)
        (* c starts at a and mflag starts true, so the placeholder d = 0.0
           is never read on the first step (c3 and c5 short-circuit). *)
        else brent_int_swap delta delta f a fa b fb a fa true 0.0 1
      end

  (* Bisection on the bracket a, b (fa * fb < 0.0).  Returns the midpoint
     once |f mid| < delta, f mid is exactly 0.0, the half-width drops below
     delta, or a and b are adjacent floats so the bracket cannot shrink
     any further.  Works for either order of a and b. *)
  fun bisect_int delta f a fa b fb =
      let val m = (a + b) * 0.5
          val fm = f m
      in
        if Real.abs fm < delta
           orelse Real.== (fm, 0.0)
           orelse Real.abs (b - a) * 0.5 < delta
           orelse not (strictly_between (a, b) m)
        then m
        else if fa * fm < 0.0
        then bisect_int delta f a fa m fm
        else bisect_int delta f m fm b fb
      end

  (* Bisection method.  Raises Fail unless f a * f b < 0.0. *)
  fun bisection delta f a b =
      let val fa = f a
          val fb = f b
      in
        if fa * fb >= 0.0
        then error_bracket ("RootFind.bisection", a, fa, b, fb)
        else bisect_int delta f a fa b fb
      end

  (* Illinois-modified regula falsi on the bracket a, b (fa * fb < 0.0).
     [kept] records which endpoint survived the previous step: ~1 for a,
     1 for b, 0 for none yet.  When the same endpoint survives twice in a
     row its stored f-value is halved.  This pulls the next secant point
     towards that endpoint, so plain regula falsi's one-sided stall cannot
     happen.  Halving keeps the sign, and only the sign of the stored
     values is used for the bracket test.  If the secant point is not
     strictly inside the bracket, a bisection step is taken instead.
     Stops under the same rules as bisect_int. *)
  fun falsi_int delta f a fa b fb kept =
      let val sec = b - fb * (b - a) / (fb - fa)
          val m = if strictly_between (a, b) sec then sec else (a + b) * 0.5
          val fm = f m
      in
        if Real.abs fm < delta
           orelse Real.== (fm, 0.0)
           orelse Real.abs (b - a) * 0.5 < delta
           orelse not (strictly_between (a, b) m)
        then m
        else if fa * fm < 0.0
        then (* sign change on [a, m]: a survives *)
          falsi_int delta f a (if kept = ~1 then fa * 0.5 else fa) m fm ~1
        else (* sign change on [m, b]: b survives *)
          falsi_int delta f m fm b (if kept = 1 then fb * 0.5 else fb) 1
      end

  (* Secant iteration kept inside the bracket a, b (fa * fb < 0.0).
     It used to take one secant step and then fall through to bisection;
     it now iterates secant steps through falsi_int. *)
  fun secant_int delta f a fa b fb = falsi_int delta f a fa b fb 0

  (* Secant method.  Raises Fail unless f a * f b < 0.0. *)
  fun secant delta f a b =
      let val fa = f a
          val fb = f b
      in
        if fa * fb >= 0.0
        then error_bracket ("RootFind.secant", a, fa, b, fb)
        else secant_int delta f a fa b fb
      end

  (* Sample functions for test. *)
  fun f x = 4.0 * x * x * x - 16.0 * x * x + 17.0 * x - 4.0
  (* Actual roots:
     0.3285384586114149
     1.2646582900644197
     2.4068032513241651
  *)
  fun f1 x = (x + 3.0) * (x - 1.0) * (x - 1.0)
  (* roots: -3, 1 (double root) *)
  fun f2 x = (Math.tan x) - 2.0 * x
  (* root: 1.16556118520721 *)

  (* Print the roots found by [get_root], which is called as
     get_root fn lo hi, for the sample functions f, f1 and f2. *)
  fun test get_root =
      let val root0 = get_root f 0.0 1.0
          val root1 = get_root f 1.0 2.0
          val root2 = get_root f 2.0 3.0
      in
        Printf.printf `"Roots of 4x^3-16x^2+17x-4 are:" `"\n"R `"\n"R `"\n"R `"\n" $ root0 root1 root2;
        let val root0 = get_root f1 (~4.0) (4.0 / 3.0)
        in
          Printf.printf `"One root of (x+3)(x-1)^2 is:" `"\n"R `"\n" $ root0;
          let val root0 = get_root f2 0.5 1.5 in
            Printf.printf `"One solution of tan x = 2x is:" `"\n"R `"\n" $ root0
          end
        end
      end

  (* Tolerance used by test_all for both x and f(x). *)
  val delta = 1E~15

  (* Run test with every method in turn. *)
  fun test_all () =
      (test (brent delta);
       test (secant delta);
       test (bisection delta))
end