-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #40 from leanprover/FFI
Execution of samplers through FFI
- Loading branch information
Showing
12 changed files
with
213 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import SampCert | ||
import SampCert.Samplers.Uniform.Code | ||
import SampCert.Samplers.Bernoulli.Code | ||
import SampCert.Samplers.BernoulliNegativeExponential.Code | ||
import SampCert.Samplers.Laplace.Code | ||
import SampCert.Samplers.LaplaceGen.Code | ||
import SampCert.Samplers.Geometric.Code | ||
import SampCert.Samplers.Gaussian.Code | ||
import SampCert.Samplers.GaussianGen.Code | ||
|
||
open SLang Std | ||
|
||
def comp (n : ℕ+) : SLang ℕ := do | ||
let x ← UniformPowerOfTwoSample n | ||
return x + 1 | ||
|
||
def double (n : ℕ+) : SLang (ℕ × ℕ) := do | ||
let x ← UniformPowerOfTwoSample n | ||
let y ← UniformPowerOfTwoSample n | ||
return (x,y) | ||
|
||
def main : IO Unit := do | ||
|
||
let sampleSize : ℕ := 10 | ||
let sampleNum : ℕ := 1000 | ||
|
||
let mut arr : Array ℕ := Array.mkArray (sampleSize) 0 | ||
for _ in [:sampleNum] do | ||
let r ← run <| UniformPowerOfTwoSample ⟨ sampleSize , by aesop ⟩ | ||
let v := arr[r]! | ||
arr := arr.set! r (v + 1) | ||
|
||
let mut res : Array Float := Array.mkArray (sampleSize) 0.0 | ||
for i in [:sampleSize] do | ||
let total : Float := arr[i]!.toFloat | ||
let freq : Float := total / sampleNum.toFloat | ||
res := res.set! i freq | ||
|
||
IO.println s!"Repeated uniform sampling: {res}" | ||
|
||
let mut arr2 : Array ℕ := Array.mkArray (sampleSize) 0 | ||
for _ in [:sampleNum] do | ||
let r ← run <| probWhile (fun x => x % 2 == 0) (fun _ => UniformPowerOfTwoSample ⟨ sampleSize , by aesop ⟩) 0 | ||
let v := arr2[r]! | ||
arr2 := arr2.set! r (v + 1) | ||
|
||
let mut res2 : Array Float := Array.mkArray (sampleSize) 0.0 | ||
for i in [:sampleSize] do | ||
let total : Float := arr2[i]!.toFloat | ||
let freq : Float := total / sampleNum.toFloat | ||
res2 := res2.set! i freq | ||
|
||
IO.println s!"Repeated uniform sampling with filtering: {res2}" | ||
|
||
let mut arr3 : Array ℕ := Array.mkArray (sampleSize) 0 | ||
for _ in [:sampleNum] do | ||
let r ← run <| probUntil (UniformPowerOfTwoSample ⟨ sampleSize , by aesop ⟩) (fun x => x % 2 == 0) | ||
let v := arr3[r]! | ||
arr3 := arr3.set! r (v + 1) | ||
|
||
let mut res3 : Array Float := Array.mkArray (sampleSize) 0.0 | ||
for i in [:sampleSize] do | ||
let total : Float := arr3[i]!.toFloat | ||
let freq : Float := total / sampleNum.toFloat | ||
res3 := res3.set! i freq | ||
|
||
IO.println s!"Repeated uniform sampling with filtering: {res3}" | ||
|
||
let u : ℕ ← run <| UniformSample 5 | ||
IO.println s!"**1 Uniform sample: {u}" | ||
|
||
let u : ℕ ← run <| UniformSample 10 | ||
IO.println s!"**2 Uniform sample: {u}" | ||
|
||
let u : ℕ ← run <| UniformSample 20 | ||
IO.println s!"**3 Uniform sample: {u}" | ||
|
||
let u : ℕ ← run <| UniformSample 15 | ||
IO.println s!"**4 Uniform sample: {u}" | ||
|
||
let u : ℕ × ℕ ← run <| double 15 | ||
IO.println s!"**4 Uniform sample: {u}" | ||
|
||
let mut arr4 : Array ℕ := Array.mkArray sampleSize 0 | ||
for _ in [:sampleNum] do | ||
let r ← run <| UniformSample 10 | ||
let v := arr4[r]! | ||
arr4 := arr4.set! r (v + 1) | ||
|
||
let mut res4 : Array Float := Array.mkArray sampleSize 0.0 | ||
for i in [:sampleSize] do | ||
let total : Float := arr4[i]!.toFloat | ||
let freq : Float := total / sampleNum.toFloat | ||
res4 := res4.set! i freq | ||
|
||
IO.println s!"Repeated uniform sampling plop: {res3}" | ||
|
||
let u : ℕ ← run <| UniformSample 15 | ||
IO.println s!"**4 Uniform sample: {u}" | ||
|
||
let u : Bool ← run <| BernoulliSample 1 2 (by aesop) | ||
IO.println s!"Bernoulli sample: {u}" | ||
|
||
let u : Bool ← run <| BernoulliSample 1 100 (by aesop) | ||
IO.println s!"Bernoulli sample: {u}" | ||
|
||
let u : Bool ← run <| BernoulliExpNegSample 1 2 | ||
IO.println s!"Bernoulli NE sample: {u}" | ||
|
||
let u : ℤ ← run <| DiscreteLaplaceSample 1 1 | ||
IO.println s!"Laplace sample: {u}" | ||
|
||
let u : ℤ ← run <| DiscreteLaplaceGenSample 1 1 10 | ||
IO.println s!"Laplace Gen sample: {u}" | ||
|
||
let u : ℤ ← run <| DiscreteGaussianSample 1 1 | ||
IO.println s!"Gaussian sample: {u}" | ||
|
||
let u : ℤ ← run <| DiscreteGaussianGenSample 1 1 10 | ||
IO.println s!"Gaussian Gen sample: {u}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
/** | ||
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
Released under Apache 2.0 license as described in the file LICENSE. | ||
Authors: Jean-Baptiste Tristan | ||
*/ | ||
#include <lean/lean.h> | ||
#include <iostream> | ||
#include <bit> | ||
#include <chrono> | ||
#include <random> | ||
using namespace std; | ||
|
||
typedef std::chrono::high_resolution_clock myclock; | ||
myclock::time_point beginning = myclock::now(); | ||
myclock::duration d = myclock::now() - beginning; | ||
unsigned seed = d.count(); | ||
mt19937_64 generator(seed); | ||
|
||
extern "C" lean_object * prob_UniformP2(lean_object * a, lean_object * eta) { | ||
lean_dec(eta); | ||
if (lean_is_scalar(a)) { | ||
size_t n = lean_unbox(a); | ||
if (n == 0) { | ||
lean_internal_panic("prob_UniformP2: n == 0"); | ||
} else { | ||
int lz = __countl_zero(n); | ||
int bitlength = (8*sizeof n) - lz - 1; | ||
size_t bound = 1 << bitlength; | ||
uniform_int_distribution<int> distribution(0,bound-1); | ||
size_t r = distribution(generator); | ||
lean_dec(a); | ||
return lean_box(r); | ||
} | ||
} else { | ||
lean_internal_panic("prob_UniformP2: not handling very large values yet"); | ||
} | ||
} | ||
|
||
extern "C" lean_object * prob_Pure(lean_object * a, lean_object * eta) { | ||
lean_dec(eta); | ||
return a; | ||
} | ||
|
||
extern "C" lean_object * prob_Bind(lean_object * f, lean_object * g, lean_object * eta) { | ||
lean_dec(eta); | ||
lean_object * exf = lean_apply_1(f,lean_box(0)); | ||
lean_object * pa = lean_apply_2(g,exf,lean_box(0)); | ||
return pa; | ||
} | ||
|
||
extern "C" lean_object * prob_While(lean_object * condition, lean_object * body, lean_object * init, lean_object * eta) { | ||
lean_dec(eta); | ||
lean_object * state = init; | ||
lean_inc(state); | ||
lean_inc(condition); | ||
uint8_t cond = lean_unbox(lean_apply_1(condition,state)); | ||
while (cond) { | ||
lean_inc(body); | ||
state = lean_apply_2(body,state,lean_box(0)); | ||
lean_inc(condition); | ||
lean_inc(state); | ||
cond = lean_unbox(lean_apply_1(condition,state)); | ||
} | ||
return state; | ||
} | ||
|
||
extern "C" lean_object * my_run(lean_object * a) { | ||
lean_object * comp = lean_apply_1(a,lean_box(0)); | ||
lean_object * res = lean_io_result_mk_ok(comp); | ||
return res; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters