Skip to content

Commit

Permalink
Unbreak some examples
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Jan 9, 2024
1 parent 986c44a commit 87f7b71
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 28 deletions.
27 changes: 15 additions & 12 deletions examples/mandelbrot.dx
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,28 @@
import complex
import plot

'Escape time algorithm
# Escape time algorithm

def update(c:Complex, z:Complex) -> Complex = c + (z * z)

tol = 2.0
def inBounds(z:Complex) -> Bool = complex_abs z < tol
def inBounds(z:Complex) -> Bool = complex_abs(z) < tol

def escapeTime(c:Complex) -> Float =
fst $ fold (0.0, zero) $ \i:(Fin 1000) s.
(n, z) = s
z' = update c z
(n + b_to_f (inBounds z'), z')
def escapeTime(c:Complex) -> Nat =
z <- with_state(zero :: Complex)
bounded_iter(1000, 1000) \i.
case inBounds(get(z)) of
False -> Done(i)
True ->
z := update(c, get(z))
Continue

'Evaluate on a grid and plot the results
# Evaluate on a grid and plot the results

xs = linspace (Fin 300) (-2.0) 1.0
ys = linspace (Fin 200) (-1.0) 1.0
xs = linspace(Fin 300, -2.0, 1.0)
ys = linspace(Fin 200, -1.0, 1.0)

escapeGrid = for j i. escapeTime (Complex xs[i] ys[j])
escapeGrid = each(ys) \y. each xs \x. n_to_f(escapeTime(Complex(x, y)))

This comment has been minimized.

Copy link
@duvenaud

duvenaud Jan 9, 2024

Contributor

each(ys) but each xs?


:html matshow (-escapeGrid)
:html matshow(-escapeGrid)
> <html output>
24 changes: 13 additions & 11 deletions examples/mcmc.dx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def runChain(
numSamples: Nat,
k:Key
) -> Fin numSamples => a given (a|Data) =
[k1, k2] = split_key k
[k1, k2] = split_key(n=2, k)
with_state (initialize k1) \s.
for i:(Fin numSamples).
x = step (ixkey k2 i) (get s)
Expand All @@ -24,13 +24,13 @@ def propose(
cur: a,
proposal: a,
k: Key
) -> a given (a) =
) -> a given (a:Type) =
accept = logDensity proposal > (logDensity cur + log (rand k))
select accept proposal cur

def meanAndCovariance(xs:n=>d=>Float) -> (d=>Float, d=>d=>Float) given (n|Ix, d|Ix) =
xsMean : d=>Float = (for i. sum for j. xs[j,i]) / n_to_f (size n)
xsCov : d=>d=>Float = (for i i'. sum for j.
xsMean : d=>Float = (for i:d. sum for j:n. xs[j,i]) / n_to_f (size n)
xsCov : d=>d=>Float = (for i:d i':d. sum for j:n.
(xs[j,i'] - xsMean[i']) *
(xs[j,i ] - xsMean[i ]) ) / (n_to_f (size n) - 1)
(xsMean, xsCov)
Expand All @@ -45,7 +45,7 @@ def mhStep(
k:Key,
x:d=>Float
) -> d=>Float given (d|Ix) =
[k1, k2] = split_key k
[k1, k2] = split_key(n=2, k)
proposal = x + stepSize .* randn_vec k1
propose logProb x proposal k2

Expand Down Expand Up @@ -80,8 +80,8 @@ def hmcStep(
) -> d=>Float given (d|Ix) =
def hamiltonian(s:HMCState (d=>Float)) -> Float =
logProb s.x - 0.5 * vdot s.p s.p
[k1, k2] = split_key k
p = randn_vec k1
[k1, k2] = split_key(n=2, k)
p = randn_vec k1 :: d => Float
proposal = leapfrogIntegrate params logProb HMCState(x, p)
final = propose hamiltonian HMCState(x, p) proposal k2
final.x
Expand All @@ -93,6 +93,8 @@ def hmcStep(
def myLogProb(x:(Fin 2)=>Float) -> LogProb =
x' = x - [1.5, 2.5]
neg $ 0.5 * inner x' [[1.,0.],[0.,20.]] x'
def myInitializer(k:Key) -> Fin 2 => Float =
randn_vec(k)

numSamples : Nat =
if dex_test_mode()
Expand All @@ -101,21 +103,21 @@ numSamples : Nat =
k0 = new_key 1

mhParams = 0.1
mhSamples = runChain randn_vec (\k x. mhStep mhParams myLogProb k x) numSamples k0
mhSamples = runChain myInitializer (\k x. mhStep mhParams myLogProb k x) numSamples k0

:p meanAndCovariance mhSamples
> ([0.5455918, 2.522631], [[0.3552593, 0.05022133], [0.05022133, 0.08734216]])

:html show_plot $ y_plot $
slice (map head mhSamples) 0 (Fin 1000)
slice (each mhSamples head) 0 (Fin 1000)
> <html output>

hmcParams = HMCParams(10, 0.1)
hmcSamples = runChain randn_vec (\k x. hmcStep hmcParams myLogProb k x) numSamples k0
hmcSamples = runChain myInitializer (\k x. hmcStep hmcParams myLogProb k x) numSamples k0

:p meanAndCovariance hmcSamples
> ([1.472011, 2.483082], [[1.054705, -0.002082013], [-0.002082013, 0.05058844]])

:html show_plot $ y_plot $
slice (map head hmcSamples) 0 (Fin 1000)
slice (each hmcSamples head) 0 (Fin 1000)
> <html output>
2 changes: 1 addition & 1 deletion examples/pi.dx
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
'To compute $\pi$, randomly sample points in the first quadrant unit square to estimate the $\frac{A_{quadrant}}{A_{square}}$ ratio. Then, multiply by $4$.

def estimatePiArea(key:Key) -> Float =
[k1, k2] = split_key key
[k1, k2] = split_key(n=2, key)
x = rand k1
y = rand k2
inBounds = (sq x + sq y) < 1.0
Expand Down
8 changes: 4 additions & 4 deletions examples/sierpinski.dx
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@ def update(points:n=>Point, key:Key, p:Point) -> Point given (n|Ix) =
p' = points[rand_idx key]
Point(0.5 * (p.x + p'.x), 0.5 * (p.y + p'.y))

def runChain(n:Nat, f:(Key, a) -> a, key:Key, x0:a) -> Fin n => a given (a|Data) =
def runChain(n:Nat, key:Key, x0:a, f:(Key, a) -> a) -> Fin n => a given (a|Data) =
ref <- with_state x0
for i:(Fin n).
prev = get ref
new = ixkey key i | f(get ref)
ref := new
new

trianglePoints : (Fin 3)=>Point = [Point(0.0, 0.0), Point(1.0, 0.0), Point(0.5, sqrt 0.75)]

points = runChain 3000 (\k p. update trianglePoints k p) (new_key 0) (Point 0.0 0.0)
n = 3000
points = runChain n (new_key 0) (Point 0.0 0.0) \k p. update trianglePoints k p

(xs, ys) = unzip for i. (points[i].x, points[i].y)
(xs, ys) = unzip for i:(Fin n). (points[i].x, points[i].y)

:html show_plot $ xy_plot xs ys
> <html output>

0 comments on commit 87f7b71

Please sign in to comment.