Skip to content

Commit

Permalink
Merge pull request #40 from leanprover/FFI
Browse files Browse the repository at this point in the history
Execution of samplers through FFI
  • Loading branch information
jtristan authored Jul 16, 2024
2 parents 06b5661 + ea62836 commit 3d900a0
Show file tree
Hide file tree
Showing 12 changed files with 213 additions and 21 deletions.
120 changes: 120 additions & 0 deletions Main.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import SampCert
import SampCert.Samplers.Uniform.Code
import SampCert.Samplers.Bernoulli.Code
import SampCert.Samplers.BernoulliNegativeExponential.Code
import SampCert.Samplers.Laplace.Code
import SampCert.Samplers.LaplaceGen.Code
import SampCert.Samplers.Geometric.Code
import SampCert.Samplers.Gaussian.Code
import SampCert.Samplers.GaussianGen.Code

open SLang Std

def comp (n : ℕ+) : SLang ℕ := do
let x ← UniformPowerOfTwoSample n
return x + 1

def double (n : ℕ+) : SLang (ℕ × ℕ) := do
let x ← UniformPowerOfTwoSample n
let y ← UniformPowerOfTwoSample n
return (x,y)

def main : IO Unit := do

let sampleSize : ℕ := 10
let sampleNum : ℕ := 1000

let mut arr : Array ℕ := Array.mkArray (sampleSize) 0
for _ in [:sampleNum] do
let r ← run <| UniformPowerOfTwoSample ⟨ sampleSize , by aesop ⟩
let v := arr[r]!
arr := arr.set! r (v + 1)

let mut res : Array Float := Array.mkArray (sampleSize) 0.0
for i in [:sampleSize] do
let total : Float := arr[i]!.toFloat
let freq : Float := total / sampleNum.toFloat
res := res.set! i freq

IO.println s!"Repeated uniform sampling: {res}"

let mut arr2 : Array ℕ := Array.mkArray (sampleSize) 0
for _ in [:sampleNum] do
let r ← run <| probWhile (fun x => x % 2 == 0) (fun _ => UniformPowerOfTwoSample ⟨ sampleSize , by aesop ⟩) 0
let v := arr2[r]!
arr2 := arr2.set! r (v + 1)

let mut res2 : Array Float := Array.mkArray (sampleSize) 0.0
for i in [:sampleSize] do
let total : Float := arr2[i]!.toFloat
let freq : Float := total / sampleNum.toFloat
res2 := res2.set! i freq

IO.println s!"Repeated uniform sampling with filtering: {res2}"

let mut arr3 : Array ℕ := Array.mkArray (sampleSize) 0
for _ in [:sampleNum] do
let r ← run <| probUntil (UniformPowerOfTwoSample ⟨ sampleSize , by aesop ⟩) (fun x => x % 2 == 0)
let v := arr3[r]!
arr3 := arr3.set! r (v + 1)

let mut res3 : Array Float := Array.mkArray (sampleSize) 0.0
for i in [:sampleSize] do
let total : Float := arr3[i]!.toFloat
let freq : Float := total / sampleNum.toFloat
res3 := res3.set! i freq

IO.println s!"Repeated uniform sampling with filtering: {res3}"

let u : ℕ ← run <| UniformSample 5
IO.println s!"**1 Uniform sample: {u}"

let u : ℕ ← run <| UniformSample 10
IO.println s!"**2 Uniform sample: {u}"

let u : ℕ ← run <| UniformSample 20
IO.println s!"**3 Uniform sample: {u}"

let u : ℕ ← run <| UniformSample 15
IO.println s!"**4 Uniform sample: {u}"

let u : ℕ × ℕ ← run <| double 15
IO.println s!"**4 Uniform sample: {u}"

let mut arr4 : Array ℕ := Array.mkArray sampleSize 0
for _ in [:sampleNum] do
let r ← run <| UniformSample 10
let v := arr4[r]!
arr4 := arr4.set! r (v + 1)

let mut res4 : Array Float := Array.mkArray sampleSize 0.0
for i in [:sampleSize] do
let total : Float := arr4[i]!.toFloat
let freq : Float := total / sampleNum.toFloat
res4 := res4.set! i freq

IO.println s!"Repeated uniform sampling plop: {res3}"

let u : ℕ ← run <| UniformSample 15
IO.println s!"**4 Uniform sample: {u}"

let u : Bool ← run <| BernoulliSample 1 2 (by aesop)
IO.println s!"Bernoulli sample: {u}"

let u : Bool ← run <| BernoulliSample 1 100 (by aesop)
IO.println s!"Bernoulli sample: {u}"

let u : Bool ← run <| BernoulliExpNegSample 1 2
IO.println s!"Bernoulli NE sample: {u}"

let u : ℤ ← run <| DiscreteLaplaceSample 1 1
IO.println s!"Laplace sample: {u}"

let u : ℤ ← run <| DiscreteLaplaceGenSample 1 1 10
IO.println s!"Laplace Gen sample: {u}"

let u : ℤ ← run <| DiscreteGaussianSample 1 1
IO.println s!"Gaussian sample: {u}"

let u : ℤ ← run <| DiscreteGaussianGenSample 1 1 10
IO.println s!"Gaussian Gen sample: {u}"
11 changes: 8 additions & 3 deletions SampCert/SLang.lean
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ space should be interpreted as discrete.

open Classical Nat ENNReal PMF

noncomputable section

/--
The monad of ``SLang`` values
-/
Expand Down Expand Up @@ -64,11 +62,13 @@ def probZero : SLang T := λ _ : T => 0
/--
The Dirac distribution as a ``SLang``.
-/
@[extern "prob_Pure"]
def probPure (a : T) : SLang T := fun a' => if a' = a then 1 else 0

/--
Monadic bind for ``SLang`` values.
-/
@[extern "prob_Bind"]
def probBind (p : SLang T) (f : T → SLang U) : SLang U :=
fun b => ∑' a, p a * f a b

Expand All @@ -81,9 +81,9 @@ instance : Monad SLang where
the number``m`` is the largest power of two that is at most ``n``.
-/
-- MARKUSDE: I would like to change this to ``probUniformP2`` once it doesn't break extraction.
@[extern "prob_UniformP2"]
def UniformPowerOfTwoSample (n : ℕ+) : SLang ℕ :=
toSLang (PMF.uniformOfFintype (Fin (2 ^ (log 2 n))))
--((PMF.uniformOfFintype (Fin (2 ^ (log 2 n)))) : PMF ℕ).1

/--
``SLang`` functional which executes ``body`` only when ``cond`` is ``false``.
Expand All @@ -108,14 +108,19 @@ def probWhileCut (cond : T → Bool) (body : T → SLang T) (n : Nat) (a : T) :
/--
``SLang`` value for an unbounded iteration of a loop.
-/
@[extern "prob_While"]
def probWhile (cond : T → Bool) (body : T → SLang T) (init : T) : SLang T :=
fun x => ⨆ (i : ℕ), (probWhileCut cond body i init x)

/--
``SLang`` value which rejects samples from ``body`` until they satisfy ``cond``.
-/
--@[extern "prob_Until"]
def probUntil (body : SLang T) (cond : T → Bool) : SLang T := do
let v ← body
probWhile (λ v : T => ¬ cond v) (λ _ : T => body) v

@[extern "my_run"]
opaque run (c : SLang T) : IO T

end SLang
2 changes: 0 additions & 2 deletions SampCert/Samplers/Bernoulli/Code.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ import SampCert.Samplers.Uniform.Code
The term ``BernoulliSample`` violates our naming scheme, but this is currently necessary for extraction.
-/

noncomputable section

namespace SLang

/--
Expand Down
3 changes: 0 additions & 3 deletions SampCert/Samplers/BernoulliNegativeExponential/Code.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@ The following identifiers violate our naming scheme, but are currently necessary
- ``BernoulliExpNegSample``
-/


noncomputable section

namespace SLang

lemma halve_wf (num : Nat) (den st : PNat) (wf : num ≤ den) :
Expand Down
3 changes: 0 additions & 3 deletions SampCert/Samplers/Gaussian/Code.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ The following identifiers violate our naming scheme, but are currently necessary
- ``DiscreteGaussianSample``
-/


noncomputable section

namespace SLang

/--
Expand Down
2 changes: 0 additions & 2 deletions SampCert/Samplers/GaussianGen/Code.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ The identifier ``DiscreteGaussianGenSample`` violates our naming scheme, however
this way for parity with ``DiscreteGaussianGen``.
-/

noncomputable section

namespace SLang

/--
Expand Down
2 changes: 0 additions & 2 deletions SampCert/Samplers/Geometric/Code.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ import SampCert.SLang
# ``probGeometric`` Implementation
-/

noncomputable section

namespace SLang

section Geometric
Expand Down
2 changes: 0 additions & 2 deletions SampCert/Samplers/Laplace/Code.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ The following identifiers violate our naming scheme, but are currently necessary
- ``DiscreteLaplaceSample``
-/

noncomputable section

namespace SLang

def DiscreteLaplaceSampleLoopIn1Aux (t : PNat) : SLang (Nat × Bool) := do
Expand Down
2 changes: 0 additions & 2 deletions SampCert/Samplers/LaplaceGen/Code.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ Authors: Jean-Baptiste Tristan
import SampCert.SLang
import SampCert.Samplers.Laplace.Code

noncomputable section

namespace SLang

def DiscreteLaplaceGenSample (num : PNat) (den : PNat) (μ : ℤ) : SLang ℤ := do
Expand Down
2 changes: 0 additions & 2 deletions SampCert/Samplers/Uniform/Code.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ This file contains the implementation for a uniform sampler over a finite set.
``UniformSample`` violates our naming scheme, but this is currently necessary for extraction.
-/

noncomputable section

namespace SLang

/-- ``Slang`` term for a uniform sample over [0, n). Implemented using rejection sampling on
Expand Down
71 changes: 71 additions & 0 deletions ffi.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/**
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jean-Baptiste Tristan
*/
#include <lean/lean.h>
#include <iostream>
#include <bit>
#include <chrono>
#include <random>
using namespace std;

typedef std::chrono::high_resolution_clock myclock;
myclock::time_point beginning = myclock::now();
myclock::duration d = myclock::now() - beginning;
unsigned seed = d.count();
mt19937_64 generator(seed);

extern "C" lean_object * prob_UniformP2(lean_object * a, lean_object * eta) {
lean_dec(eta);
if (lean_is_scalar(a)) {
size_t n = lean_unbox(a);
if (n == 0) {
lean_internal_panic("prob_UniformP2: n == 0");
} else {
int lz = __countl_zero(n);
int bitlength = (8*sizeof n) - lz - 1;
size_t bound = 1 << bitlength;
uniform_int_distribution<int> distribution(0,bound-1);
size_t r = distribution(generator);
lean_dec(a);
return lean_box(r);
}
} else {
lean_internal_panic("prob_UniformP2: not handling very large values yet");
}
}

extern "C" lean_object * prob_Pure(lean_object * a, lean_object * eta) {
lean_dec(eta);
return a;
}

extern "C" lean_object * prob_Bind(lean_object * f, lean_object * g, lean_object * eta) {
lean_dec(eta);
lean_object * exf = lean_apply_1(f,lean_box(0));
lean_object * pa = lean_apply_2(g,exf,lean_box(0));
return pa;
}

extern "C" lean_object * prob_While(lean_object * condition, lean_object * body, lean_object * init, lean_object * eta) {
lean_dec(eta);
lean_object * state = init;
lean_inc(state);
lean_inc(condition);
uint8_t cond = lean_unbox(lean_apply_1(condition,state));
while (cond) {
lean_inc(body);
state = lean_apply_2(body,state,lean_box(0));
lean_inc(condition);
lean_inc(state);
cond = lean_unbox(lean_apply_1(condition,state));
}
return state;
}

extern "C" lean_object * my_run(lean_object * a) {
lean_object * comp = lean_apply_1(a,lean_box(0));
lean_object * res = lean_io_result_mk_ok(comp);
return res;
}
14 changes: 14 additions & 0 deletions lakefile.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,17 @@ lean_lib «VMC» where
-- From doc-gen4
meta if get_config? env = some "doc" then
require «doc-gen4» from git "https://github.com/leanprover/doc-gen4" @ "main"

target ffi.o pkg : FilePath := do
let oFile := pkg.buildDir / "ffi.o"
let srcJob ← inputTextFile <| pkg.dir / "ffi.cpp"
let weakArgs := #["-I", (← getLeanIncludeDir).toString]
buildO oFile srcJob weakArgs #["-fPIC"] "c++" getLeanTrace

extern_lib libleanffi pkg := do
let ffiO ← ffi.o.fetch
let name := nameToStaticLib "leanffi"
buildStaticLib (pkg.nativeLibDir / name) #[ffiO]

lean_exe test where
root := `Main

0 comments on commit 3d900a0

Please sign in to comment.