Skip to content

Commit

Permalink
resolve some comments and finish foundations/until
Browse files Browse the repository at this point in the history
  • Loading branch information
mjdemedeiros committed May 22, 2024
1 parent 491dd30 commit 1ba7b86
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 52 deletions.
6 changes: 3 additions & 3 deletions SampCert/Foundations/UniformP2.lean
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ Evaluates the ``uniformPowerOfTwoSample`` distribution at a point inside of its
theorem uniformPowerOfTwoSample_apply (n : PNat) (x : Nat) (h : x < 2 ^ (log 2 n)) :
(uniformPowerOfTwoSample n) x = 1 / (2 ^ (log 2 n)) := by
simp only [uniformPowerOfTwoSample, Lean.Internal.coeM, Bind.bind, Pure.pure, CoeT.coe,
CoeHTCT.coe, CoeHTC.coe, CoeOTC.coe, CoeOut.coe, toSubPMF_apply, PMF.bind_apply,
CoeHTCT.coe, CoeHTC.coe, CoeOTC.coe, CoeOut.coe, toSLang_apply, PMF.bind_apply,
uniformOfFintype_apply, Fintype.card_fin, cast_pow, cast_ofNat, PMF.pure_apply, one_div]
rw [ENNReal.tsum_mul_left]
rw [sum_indicator_finrange (2 ^ (log 2 n)) x]
Expand All @@ -109,7 +109,7 @@ Evaluates the ``uniformPowerOfTwoSample`` distribution at a point outside of its
theorem uniformPowerOfTwoSample_apply' (n : PNat) (x : Nat) (h : x ≥ 2 ^ (log 2 n)) :
uniformPowerOfTwoSample n x = 0 := by
simp only [uniformPowerOfTwoSample, Lean.Internal.coeM, Bind.bind, Pure.pure, CoeT.coe,
CoeHTCT.coe, CoeHTC.coe, CoeOTC.coe, CoeOut.coe, toSubPMF_apply, PMF.bind_apply,
CoeHTCT.coe, CoeHTC.coe, CoeOTC.coe, CoeOut.coe, toSLang_apply, PMF.bind_apply,
uniformOfFintype_apply, Fintype.card_fin, cast_pow, cast_ofNat, PMF.pure_apply,
ENNReal.tsum_eq_zero, _root_.mul_eq_zero, ENNReal.inv_eq_zero, ENNReal.pow_eq_top_iff,
ENNReal.two_ne_top, ne_eq, log_eq_zero_iff, reduceLE, or_false, not_lt, false_and, false_or]
Expand Down Expand Up @@ -149,7 +149,7 @@ theorem uniformPowerOfTwoSample_normalizes (n : PNat) :
. simp only [ge_iff_le, le_add_iff_nonneg_left, _root_.zero_le, uniformPowerOfTwoSample_apply',
tsum_zero, add_zero]
simp only [uniformPowerOfTwoSample, Lean.Internal.coeM, Bind.bind, Pure.pure, CoeT.coe,
CoeHTCT.coe, CoeHTC.coe, CoeOTC.coe, CoeOut.coe, toSubPMF_apply, PMF.bind_apply,
CoeHTCT.coe, CoeHTC.coe, CoeOTC.coe, CoeOut.coe, toSLang_apply, PMF.bind_apply,
uniformOfFintype_apply, Fintype.card_fin, cast_pow, cast_ofNat, PMF.pure_apply]
rw [Finset.sum_range]
conv =>
Expand Down
58 changes: 41 additions & 17 deletions SampCert/Foundations/Until.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import SampCert.Util.Util
/-!
# Until
Evaluation lemmas for the ``until`` term of ``SLang``
Evaluation lemmas for the ``until`` term of ``SLang``.
-/

Expand All @@ -25,7 +25,9 @@ variable {T : Type}
-- Might make some of the proofs simpler and help put them in SNF. Any extraction reason not to?

-- MARKUSDE: Maybe needs better name, since it's not about until
/-- Truncation of ``until`` program to zero unrollings is identically zero -/
/--
Truncation of ``until`` program to zero unrollings is identically zero
-/
@[simp]
theorem until_zero (st : T) (body : SLang T) (cond : T → Bool) (x : T) :
probWhileCut (fun v => decide (cond v = false)) (fun _ => body) 0 st x = 0 := by
Expand All @@ -35,8 +37,10 @@ theorem until_zero (st : T) (body : SLang T) (cond : T → Bool) (x : T) :
-- MARKUSDE: These lemmas anger the simplifier, since it might first simplify ``decide (... = false)``.
-- Is this a problem?

/-- Truncation of ``until`` program to any number of unrollings evaluates to zero for
values which do not satisfy ``cond``. -/
/--
Truncation of ``until`` program to any number of unrollings will evaluate to zero for
values which do not satisfy ``cond``.
-/
@[simp]
theorem repeat_apply_unsat (body : SLang T) (cond : T → Bool) (fuel : ℕ) (i x : T) (h : ¬ cond x) :
probWhileCut (fun v => decide (cond v = false)) (fun _ => body) fuel i x = 0 := by
Expand All @@ -57,7 +61,9 @@ theorem repeat_apply_unsat (body : SLang T) (cond : T → Bool) (fuel : ℕ) (i
simp [h'] at h
. simp

/-- ``until`` evaluates to zero for values which do not satisfy ``cond`` -/
/--
``until`` evaluates to zero at all values which do not satisfy ``cond``
-/
@[simp]
theorem prob_until_apply_unsat (body : SLang T) (cond : T → Bool) (x : T) (h : ¬ cond x) :
probUntil (body : SLang T) (cond : T → Bool) x = 0 := by
Expand Down Expand Up @@ -104,8 +110,10 @@ theorem repeat_1 (body : SLang T) (cond : T → Bool) (x : T) (h : cond x) :
rw [if_simpl]
simp

-- MARKUSDE: move to util?
/-- Split a conditional series by the condition -/
-- MARKUSDE: move to util with the other series lemmas?
/--
Split a conditional series by the condition
-/
lemma tsum_split_ite_exp (cond : T → Bool) (f g : T → ENNReal) :
(∑' (i : T), if cond i = false then f i else g i)
= (∑' i : T, if cond i = false then f i else 0) + (∑' i : T, if cond i = true then g i else 0) := by
Expand All @@ -126,6 +134,9 @@ lemma tsum_split_ite_exp (cond : T → Bool) (f g : T → ENNReal) :
contradiction

-- MARKUSDE: TODO rename or implement repeat
/--
Closed form for truncated version of ``until``
-/
theorem repeat_closed_form (body : SLang T) (cond : T → Bool) (fuel : ℕ) (x : T) (h1 : cond x) :
∑' (i : T), body i * probWhileCut (fun v => decide (cond v = false)) (fun _ => body) fuel i x
= ∑ i in range fuel, body x * (∑' x : T, if cond x then 0 else body x)^i := by
Expand Down Expand Up @@ -198,25 +209,30 @@ theorem repeat_closed_form (body : SLang T) (cond : T → Bool) (fuel : ℕ) (x

-- MARKUSDE: TODO/rename
-- MARKUSDE: This is simple, where is it used?
theorem convergence (body : SLang T) (cond : T → Bool) (x : T) :
/--
Expression for the limit of the closed form of truncated ``until``
-/
lemma convergence (body : SLang T) (cond : T → Bool) (x : T) :
⨆ fuel, ∑ i in range fuel, body x * (∑' x : T, if cond x then 0 else body x)^i
= body x * (1 - ∑' x : T, if cond x then 0 else body x)⁻¹ := by
rw [← ENNReal.tsum_eq_iSup_nat]
rw [ENNReal.tsum_mul_left]
rw [ENNReal.tsum_geometric]

-- MARKUSDE: TODO/rename (or define repeat)
/-- Truncated ``until`` term is monotone in the maximum number of steps -/
/--
Truncated ``until`` term is monotone (as in pointwise, with results in ℝ≥0∞) in the maximum number of steps.
-/
theorem repeat_monotone (body : SLang T) (cond : T → Bool) (x : T) :
∀ (a : T), Monotone fun i => body a * probWhileCut (fun v => decide (cond v = false)) (fun _ => body) i a x := by
intro a
have A := @probWhileCut_monotonic T (fun v => decide (cond v = false)) (fun _ => body) a x
exact Monotone.const_mul' A (body a)

-- MARKUSDE: err-- what if this sum is 1? What if it's greater than 1? Is ``until`` only meaninfgul when
-- body is normalized?
-- MARKUSDE: reduce proof
/-- ``until`` term evaluates to ``body``, scaled by ?? -/
-- MARKUSDE: err-- what if this sum is 1? What if it's greater than 1? Is ``until`` only meaninfgul when body is normalized?
/--
``until`` term evaluates to ``body``, normalizing by the total mass of elements which satisfy ``cond``.
-/
@[simp]
theorem prob_until_apply_sat (body : SLang T) (cond : T → Bool) (x : T) (h : cond x) :
probUntil (body : SLang T) (cond : T → Bool) x
Expand Down Expand Up @@ -249,7 +265,12 @@ theorem prob_until_apply_sat (body : SLang T) (cond : T → Bool) (x : T) (h : c
intro j
rw [← ENNReal.tsum_eq_iSup_sum]

-- MARKUSDE: ??
/--
Closed form for evaluation of ``until``. ``until`` is:
- zero outside support of ``cond``
- ``body`` inside the support of ``cond``
rescaled by the total mass outside the support of ``cond``.
-/
@[simp]
theorem prob_until_apply (body : SLang T) (cond : T → Bool) (x : T) :
probUntil (body : SLang T) (cond : T → Bool) x =
Expand All @@ -260,8 +281,12 @@ theorem prob_until_apply (body : SLang T) (cond : T → Bool) (x : T) :
. rename_i h
simp [h, prob_until_apply_unsat]

-- MARKUSDE: Is this not the same conclusion as the last lemma?
-- MARKUSDE: How is norm used?
/--
When ``body`` is a proper PMF, ``until`` is
- zero outside the support of ``cond``
- ``body`` inside the support of ``cond``
normalized into a PMF.
-/
@[simp]
theorem prob_until_apply_norm (body : SLang T) (cond : T → Bool) (x : T) (norm : ∑' x : T, body x = 1) :
probUntil (body : SLang T) (cond : T → Bool) x =
Expand Down Expand Up @@ -290,4 +315,3 @@ theorem prob_until_apply_norm (body : SLang T) (cond : T → Bool) (x : T) (norm
rw [ENNReal.add_sub_cancel_right F]

end SLang
-- #lint docBlame
10 changes: 7 additions & 3 deletions SampCert/Foundations/While.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Mathlib.Probability.ProbabilityMassFunction.Constructions
/-!
# While
This file proves properties about the ``while`` term of ``SLang``
This file proves properties about the ``while`` term of ``SLang``.
-/

Expand All @@ -23,7 +23,9 @@ variable {T} [Preorder T]

-- MARKUSDE: Is it true that ≤ ordering between SLang terms here is pointwise? How would I check?
-- In Rocq, I would Unset Printing Notations-- what about Lean?
/-- The ``while`` program is monotonic (as a pointwise function) in terms of the number of unrollings. -/
/--
The ``while`` program is monotonic (as a pointwise function) in terms of the number of unrollings.
-/
theorem probWhileCut_monotonic (cond : T → Bool) (body : T → SLang T) (init : T) (x : T) :
Monotone (fun n : Nat => probWhileCut cond body n init x) := by
apply monotone_nat_of_le_succ
Expand All @@ -46,7 +48,9 @@ theorem probWhileCut_monotonic (cond : T → Bool) (body : T → SLang T) (init
exact IH a
. simp

/-- The ``probWhile`` term evaluates to the pointwise limit of the ``probWhileCut`` term -/
/--
The ``probWhile`` term evaluates to the pointwise limit of the ``probWhileCut`` term
-/
@[simp]
theorem probWhile_apply (cond : T → Bool) (body : T → SLang T) (init : T) (x : T) (v : ENNReal) :
Filter.Tendsto (fun i => probWhileCut cond body i init x) Filter.atTop (nhds v) →
Expand Down
56 changes: 40 additions & 16 deletions SampCert/SLang.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,20 @@ open Classical Nat ENNReal PMF

noncomputable section

/-- The monad of ``SLang`` values -/
/--
The monad of ``SLang`` values
-/
abbrev SLang.{u} (α : Type u) : Type u := α → ℝ≥0

namespace PMF

/-! ## Coercions between a ``SLang`` value and PMF -/

/-- ``SLang`` value from explicit probability mass function -/
/--
``SLang`` value from ``PMF``
-/
def toSLang (p : PMF α) : SLang α := p.1

-- MARKUSDE: SubPMF? This is misnamed, right?
@[simp]
theorem toSubPMF_apply (p : PMF α) (x : α) : (toSLang p x) = p x := by
theorem toSLang_apply (p : PMF α) (x : α) : (toSLang p x) = p x := by
unfold toSLang
unfold DFunLike.coe
unfold instFunLike
Expand All @@ -51,31 +52,43 @@ end PMF

namespace SLang

/-! ### Program terms of ``SLang`` -/
/-!
### Program terms of ``SLang``
-/

variable {T U : Type}

/-- The zero distribution as a ``SLang value`` -/
/--
The zero distribution as a ``SLang``
-/
def zero : SLang T := λ _ : T => 0

/-- The Dirac distribution as a ``SLang value`` -/
/--
The Dirac distribution as a ``SLang``
-/
def pure (a : T) : SLang T := fun a' => if a' = a then 1 else 0

/-- Monadic bind for ``SLang`` values -/
/--
Monadic bind for ``SLang`` values
-/
def bind (p : SLang T) (f : T → SLang U) : SLang U :=
fun b => ∑' a, p a * f a b

instance : Monad SLang where
pure a := pure a
bind pa pb := pa.bind pb

/-- ``SLang`` value for the uniform distribution over ``m`` elements, where ``m``
is the largest power of two that is at most ``n``. -/
/--
``SLang`` value for the uniform distribution over ``m`` elements;
the number``m`` is the largest power of two that is at most ``n``.
-/
def uniformPowerOfTwoSample (n : ℕ+) : SLang ℕ :=
toSLang (PMF.uniformOfFintype (Fin (2 ^ (log 2 n))))
--((PMF.uniformOfFintype (Fin (2 ^ (log 2 n)))) : PMF ℕ).1

/-- ``SLang`` functional which executes ``body`` only when ``cond`` is ``false`` -/
/--
``SLang`` functional which executes ``body`` only when ``cond`` is ``false``.
-/
def whileFunctional (cond : T → Bool) (body : T → SLang T) (wh : T → SLang T) : T → SLang T :=
λ a : T =>
if cond a
Expand All @@ -84,21 +97,32 @@ def whileFunctional (cond : T → Bool) (body : T → SLang T) (wh : T → SLang
wh v
else return a

/-- ``SLang`` value obtained by unrolling a loop body exactly ``n`` times -/
-- MARKUSDE: Rename me
/--
``SLang`` value obtained by unrolling a loop body exactly ``n`` times
-/
def probWhileCut (cond : T → Bool) (body : T → SLang T) (n : Nat) (a : T) : SLang T :=
match n with
| Nat.zero => zero
| succ n => whileFunctional cond body (probWhileCut cond body n) a

/-- ``SLang`` value for an unbounded iteration of a loop -/
-- MARKUSDE: Rename me
/--
``SLang`` value for an unbounded iteration of a loop
-/
def probWhile (cond : T → Bool) (body : T → SLang T) (init : T) : SLang T :=
fun x => ⨆ (i : ℕ), (probWhileCut cond body i init x)

/-- ``SLang`` value which rejects samples from ``body`` until they satisfy ``cond`` -/
-- MARKUSDE: Rename me
/--
``SLang`` value which rejects samples from ``body`` until they satisfy ``cond``
-/
def probUntil (body : SLang T) (cond : T → Bool) : SLang T := do
let v ← body
probWhile (λ v : T => ¬ cond v) (λ _ : T => body) v

-- MARKUSDE: Possibly define a truncated ``until`` operator? Many lemmas stated this way

end SLang

#lint docBlame
3 changes: 1 addition & 2 deletions SampCert/Samplers/BernoulliNegativeExponential/Code.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ on ``Bool`` which samples ``true`` with probability ``exp (-num / den)``.
This implementation uses an method for sampling from the BNE widely known as ``Von Neumann's algorithm``.
MARKUSDE: Is this really true? This implementation is different from the one I've seen, which explicitly
samples sequences of decreasing reals. "Discrete Von Neumann" maybe?
MARKUSDE: Cite?
-/

-- MARKUSDE: FIXME: mane of samplers violate the naming scheme
Expand Down
20 changes: 17 additions & 3 deletions SampCert/Samplers/Geometric/Properties.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jean-Baptiste Tristan
-/

import SampCert.Foundations.Basic
import SampCert.Samplers.Geometric.Code

/-!
# Properties of ``geometricSample``
MARKUSDE: which ones?
-/

noncomputable section

open Classical Nat
Expand All @@ -19,11 +24,16 @@ variable (trial : SLang Bool)
variable (trial_spec : trial false + trial true = 1)
variable (trial_spec' : trial true < 1)

theorem ite_test (a b : ℕ) (x y : ENNReal) :

-- MARKUSDE: Instance coercions? Why?
lemma ite_test (a b : ℕ) (x y : ENNReal) :
@ite ENNReal (a = b) (propDecidable (a = b)) x y
= @ite ENNReal (a = b) (instDecidableEqNat a b) x y := by
split ; any_goals { trivial }

/--
Trial distributions are determined by one element
-/
theorem trial_one_minus :
trial false = 1 - trial true := by
by_contra h
Expand All @@ -34,7 +44,8 @@ theorem trial_one_minus :
rw [h'] at trial_spec
simp at trial_spec

theorem trial_le_1 (i : ℕ) :
-- MARKUSDE: surely this is completely removable?
lemma trial_le_1 (i : ℕ) :
trial true ^ i ≤ 1 := by
induction i
. simp
Expand All @@ -53,6 +64,7 @@ theorem trial_le_1 (i : ℕ) :
simp at B
exact Left.mul_le_one IH A

-- MARKUSDE TODO/what is ⊤
theorem trial_sum_ne_top :
(∑' (n : ℕ), trial true ^ n) ≠ ⊤ := by
rw [ENNReal.tsum_geometric]
Expand Down Expand Up @@ -387,3 +399,5 @@ theorem geometric_normalizes' :
end Geometric

end SLang

-- #lint docBlame
2 changes: 1 addition & 1 deletion SampCert/Samplers/Laplace/Code.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ This file implements a sampler for the discrete laplace distribution in ``SLang`
## Implementation notes
MARKUSDE: what is the discrete laplace?
MARKUSDE: cite?
-/

Expand Down
Loading

0 comments on commit 1ba7b86

Please sign in to comment.