-
Notifications
You must be signed in to change notification settings - Fork 7
/
fit-mala.dx
87 lines (66 loc) · 2.45 KB
/
fit-mala.dx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
'# MALA using Dex
-- load some generic utility functions
import djwutils
'## now read and process the data
-- read the whole data file into a single string (unsafe top-level IO)
dat = unsafe_io \. read_file "../pima.data"
-- split into rows of space-separated String fields
AsList(_, tab) = parse_tsv ' ' dat
-- prepend a constant "1.0" field to every row (intercept column)
atab = map (\l. cons "1.0" l) tab
-- fix each row's length at 9 String fields (intercept + 7 covariates + response)
att = map (\r. list2tab r :: (Fin 9)=>String) atab
-- first 8 fields are the covariates (still as strings)
xStr = map (\r. slice r 0 (Fin 8)) att
-- parse covariate strings to Maybe Float ...
xmb = map (\r. map parseString r) xStr :: _=>(Fin 8)=>(Maybe Float)
-- ... then strip the Maybe (assumes every field parses -- TODO confirm data is clean)
x = map (\r. map from_just r) xmb :: _=>(Fin 8)=>Float
-- final (9th) field is the response, extracted as a vector
yStrM = map (\r. slice r 8 (Fin 1)) att
yStr = (transpose yStrM)[0@_]
-- encode response: "Yes" -> 1.0, anything else -> 0.0
y = map (\s. select (s == "Yes") 1.0 0.0) yStr
x -- echo the design matrix
y -- echo the response vector
'## now set up for MCMC
-- Log-likelihood of logistic regression coefficients b against the
-- module-level data (x, y): -sum_i log(1 + exp((1-2*y_i) * <x_i, b>)).
-- NOTE(review): log/exp apply elementwise over the vector here.
def ll(b: (Fin 8)=>Float) -> Float =
  neg $ sum (log (map (\ x. (exp x) + 1) ((map (\ yi. 1 - 2*yi) y)*(x **. b))))

pscale = [10.0, 1, 1, 1, 1, 1, 1, 1] -- prior SDs
prscale = map (\ x. 1.0/x) pscale -- reciprocal prior SDs

-- Independent zero-mean Gaussian log-prior with SDs pscale
-- (up to the additive -0.5*log(2*pi) constant per component).
def lprior(b: (Fin 8)=>Float) -> Float =
  bs = b*prscale
  neg $ sum ((log pscale) + (0.5 .* (bs*bs)))

-- Unnormalised log-posterior: log-likelihood plus log-prior.
def lpost(b: (Fin 8)=>Float) -> Float =
  (ll b) + (lprior b)
-- Gradient of the log-posterior: Gaussian prior gradient plus the
-- logistic-regression score X'(y - p), where p = 1/(1+exp(-X b))
-- are the fitted probabilities.
def glp(b: (Fin 8)=>Float) -> (Fin 8)=>Float =
  glpr = -b*prscale*prscale -- gradient of the Gaussian log-prior
  gll = (transpose x) **. (y - (map (\eta. 1.0/(1.0 + eta)) (exp (-x **. b))))
  glpr + gll

k = new_key 42 -- PRNG key seeding the whole MCMC run
-- Generic Metropolis-Hastings transition kernel.
-- lpost: target log-density; rprop: proposal sampler; dprop: proposal
-- log-density with convention dprop(to, from) = log q(to | from);
-- sll: current (state, log-density) pair; k: PRNG key.
-- Returns the next (state, log-density) pair.
def mhKernel(lpost: (s) -> Float, rprop: (s, Key) -> s, dprop: (s, s) -> Float,
    sll: (s, Float), k: Key) -> (s, Float) given (s) =
  (x0, ll0) = sll
  [k1, k2] = split_key k
  x = rprop x0 k1
  ll = lpost x
  -- log acceptance ratio, including the Hastings correction
  -- q(x0|x) - q(x|x0) for an asymmetric proposal
  a = ll - ll0 + (dprop x0 x) - (dprop x x0)
  u = rand k2
  -- accept with probability min(1, exp(a)); otherwise keep current state
  select (log u < a) (x, ll) (x0, ll0)
-- Build a MALA transition kernel for target log-density lpi with gradient
-- glpi, diagonal pre-conditioner pre, and step size dt. Returns a kernel
-- mapping a (state, log-density) pair and a key to the next pair.
def malaKernel(lpi: (Fin n=>Float) -> Float, glpi: (Fin n=>Float) -> (Fin n)=>Float,
    pre: (Fin n)=>Float, dt: Float) ->
    ((Fin n=>(Float), Float), Key) -> ((Fin n=>Float), Float) given (n)=
  sdt = sqrt dt
  spre = sqrt pre
  v = dt .* pre -- per-component proposal variances
  vinv = map (\ x. 1.0/x) v
  -- deterministic Langevin drift: half an Euler step along the
  -- pre-conditioned gradient
  def advance(beta: (Fin n)=>Float) -> (Fin n)=>Float =
    beta + (0.5*dt) .* (pre*(glpi beta))
  -- sample a proposal: drift plus pre-conditioned Gaussian noise
  def rprop(beta: (Fin n)=>Float, k: Key) -> (Fin n)=>Float =
    (advance beta) + sdt .* (spre*(randn_vec k))
  -- Gaussian proposal log-density (up to an additive constant) of
  -- new given old, centred at advance(old) with variances v
  def dprop(new: (Fin n)=>Float, old: (Fin n)=>Float) -> Float =
    ao = advance old
    diff = new - ao
    -0.5 * sum ((log v) + diff*diff*vinv)
  -- assemble the Metropolis-Hastings kernel with this proposal pair
  \s k. mhKernel lpi rprop dprop s k
pre = [100.0,1,1,1,1,1,25,1] -- diagonal pre-conditioner
-- MALA kernel targeting the posterior, with a small fixed step size
kern = malaKernel lpost glp pre 1.0e-5
-- initial state: intercept -9, all other coefficients 0
init = [-9.0,0,0,0,0,0,0,0]
-- run the chain: 10000 saved iterations, thinning by 1000 kernel steps each;
-- -1.0e50 is a dummy initial log-posterior so the first proposal is accepted
out = markov_chain (init, -1.0e50) (\s k. step_n 1000 kern s k) 10000 k
mat = map fst out -- ditch log-posterior evaluations
mv = meanAndCovariance mat
fst mv -- mean
snd mv -- (co)variance matrix
-- save the sampled coefficients for external analysis
unsafe_io \. write_file "fit-mala.tsv" (to_tsv mat)
-- eof