Skip to content

Commit

Permalink
derived to_coq progress, refactor forgiveness
Browse files Browse the repository at this point in the history
  • Loading branch information
rtjoa committed May 14, 2024
1 parent 7e4b0b5 commit a5fe094
Show file tree
Hide file tree
Showing 7 changed files with 271 additions and 118 deletions.
71 changes: 55 additions & 16 deletions examples/qc/benchmarks/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ abstract type Generation{T} end

abstract type Property{T} end

function workload_of(::Type{<:GenerationParams{T}}) where T
T
end

function run_benchmark(
rs::RunState,
generation_params::GenerationParams{T},
Expand Down Expand Up @@ -201,12 +205,18 @@ function produce_loss(rs::RunState, m::SamplingEntropyLossMgr, epoch::Integer)
meets = m.consider(sample)
meets && (num_meeting += 1)

if meets || (m.p.rand_forgiveness && rand(rs.rng) < m.p.forgiveness)
loss_here = lpr_eq_expanded * compute(a, lpr_eq_expanded)
actual_loss_here = lpr_eq_expanded
loss_here, actual_loss_here = if meets || (m.p.rand_forgiveness && rand(rs.rng) < m.p.forgiveness)
(
lpr_eq_expanded * compute(a, lpr_eq_expanded),
lpr_eq_expanded
)
elseif !meets && !m.p.rand_forgiveness
loss_here = Dice.Constant(m.p.forgiveness) * lpr_eq_expanded * compute(a, lpr_eq_expanded)
actual_loss_here = Dice.Constant(m.p.forgiveness) * lpr_eq_expanded
(
Dice.Constant(m.p.forgiveness) * lpr_eq_expanded * compute(a, lpr_eq_expanded),
Dice.Constant(m.p.forgiveness) * lpr_eq_expanded
)
else
Dice.Constant(0), Dice.Constant(0)
end

if !meets
Expand Down Expand Up @@ -333,17 +343,43 @@ end
##################################

abstract type STLC <: Benchmark end
function sandwich(::Type{STLC})
(
"From QuickChick Require Import QuickChick. Import QcNotation.
From Coq Require Import Bool ZArith List. Import ListNotations.
From ExtLib Require Import Monad.
From ExtLib.Data.Monads Require Import OptionMonad.
Import MonadNotation.
From STLC Require Import Impl Spec.",
"Definition test_prop_SinglePreserve :=
forAll gSized (fun (e: Expr) =>
prop_SinglePreserve e).
(*! QuickChick test_prop_SinglePreserve. *)
Definition test_prop_MultiPreserve :=
forAll gSized (fun (e: Expr) =>
prop_MultiPreserve e).
(*! QuickChick test_prop_MultiPreserve. *)
"
)
end


struct STLCGeneration <: Generation{STLC}
e::Opt.T{Expr.T}
constructors_overapproximation::Vector{Opt.T{Expr.T}}
end
function generation_emit_stats(rs::RunState, g::STLCGeneration, s::String)
println_flush(rs.io, "Saving samples...")
time_sample = @elapsed with_concrete_ad_flips(rs.var_vals, g.e) do
save_samples(rs, joinpath(rs.out_dir, "terms_$(s).txt"), g.e)
end
println(rs.io, " $(time_sample) seconds")
println(rs.io)
# TODO: uncomment
# println_flush(rs.io, "Saving samples...")
# time_sample = @elapsed with_concrete_ad_flips(rs.var_vals, g.e) do
# save_samples(rs, joinpath(rs.out_dir, "terms_$(s).txt"), g.e)
# end
# println(rs.io, " $(time_sample) seconds")
# println(rs.io)
end
value(g::STLCGeneration) = g.e

Expand All @@ -353,18 +389,18 @@ value(g::STLCGeneration) = g.e

struct DerivedGenerator{T} <: GenerationParams{T}
root_ty::Type
init_size::Integer
ty_sizes::Dict{Type, Integer}
stack_size::Integer
intwidth::Integer
end
DerivedGenerator{T}(; root_ty, init_size, stack_size, intwidth) where T =
DerivedGenerator{T}(root_ty, init_size, stack_size, intwidth)
DerivedGenerator{T}(; root_ty, ty_sizes, stack_size, intwidth) where T =
DerivedGenerator{T}(root_ty, ty_sizes, stack_size, intwidth)
function to_subpath(p::DerivedGenerator{T}) where T
[
lowercase(string(T)),
"derived",
"root_ty=$(p.root_ty)",
"init_size=$(p.init_size)",
"ty-sizes=$(join(["$(ty)-$(size)" for (ty, size) in p.ty_sizes],"-"))",
"stack_size=$(p.stack_size)",
"intwidth=$(p.intwidth)",
]
Expand All @@ -377,7 +413,7 @@ function generate(rs::RunState, p::DerivedGenerator{T}) where T
end
e = generate(rs, p, add_ctor)
if T == STLC
STLCGeneration(e, constructors_overapproximation)
STLCGeneration(Opt.Some(e), constructors_overapproximation)
elseif T == BST
BSTGeneration(e, constructors_overapproximation)
elseif T == RBT
Expand All @@ -386,6 +422,9 @@ function generate(rs::RunState, p::DerivedGenerator{T}) where T
error()
end
end
function generation_params_emit_stats(rs::RunState, p::DerivedGenerator, s)
save_coq_generator(rs, p, s, derived_to_coq)
end

function save_coq_generator(rs, p, s, f)
path = joinpath(rs.out_dir, "$(s)_Generator.v")
Expand Down
46 changes: 25 additions & 21 deletions examples/qc/benchmarks/lib/stlc/dist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ end
module Expr
using Dice
using Main: DistNat, Typ
@inductive T Var(DistNat) Boolean(AnyBool) Abs(Typ.T, T) App(T, T)
@inductive T Var(DistNat) Bool(AnyBool) Abs(Typ.T, T) App(T, T)
end

to_coq(::Type{Expr.T}) = "Expr"
to_coq(::Type{Typ.T}) = "Typ"

function term_size(e::Expr.T)
match(e, [
:Var => (i) -> DistUInt32(1),
:Boolean => (b) -> DistUInt32(1),
:Bool => (b) -> DistUInt32(1),
:App => (f, x) -> DistUInt32(1) + term_size(f) + term_size(x),
:Abs => (ty, e′) -> DistUInt32(1) + term_size(e′),
])
Expand All @@ -37,23 +40,23 @@ end
function num_apps(e::Expr.T)
match(e, [
:Var => (i) -> DistUInt32(0),
:Boolean => (b) -> DistUInt32(0),
:Bool => (b) -> DistUInt32(0),
:App => (f, x) -> DistUInt32(1) + num_apps(f) + num_apps(x),
:Abs => (ty, e′) -> num_apps(e′),
])
end

stlc_ctor_to_id = Dict(
:Var => DistInt32(0),
:Boolean => DistInt32(1),
:Bool => DistInt32(1),
:App => DistInt32(2),
:Abs => DistInt32(3),
)

function ctor_to_id(ctor::Expr.T)
match(ctor, [
:Var => _ -> stlc_ctor_to_id[:Var]
:Boolean => _ -> stlc_ctor_to_id[:Boolean]
:Bool => _ -> stlc_ctor_to_id[:Bool]
:App => (_, _) -> stlc_ctor_to_id[:App]
:Abs => (_, _) -> stlc_ctor_to_id[:Abs]
])
Expand All @@ -69,7 +72,7 @@ end
function collect_constructors(e)
match(e, [
:Var => (i) -> DistVector([stlc_ctor_to_id[:Var]]),
:Boolean => (b) -> DistVector([stlc_ctor_to_id[:Boolean]]),
:Bool => (b) -> DistVector([stlc_ctor_to_id[:Bool]]),
:App => (f, x) -> prob_append(prob_extend(collect_constructors(f), collect_constructors(x)), stlc_ctor_to_id[:App]),
:Abs => (ty, e′) -> prob_append(collect_constructors(e′), stlc_ctor_to_id[:Abs]),
])
Expand Down Expand Up @@ -114,7 +117,7 @@ function stlc_str(ast, depth=0, p=free)
# i is the number of steps from the *top* of the env, see gen_var
var_depth = depth - i - 1
var_str(var_depth)
elseif name == :Boolean
elseif name == :Bool
v, = children
string(v)
elseif name == :Abs
Expand Down Expand Up @@ -189,7 +192,7 @@ function typecheck(ast::Expr.T, gamma, depth=0)::Opt.T{Typ.T}
haskey(gamma, var_depth) || return Opt.None(Typ.T)
Opt.Some(gamma[var_depth])
end,
Boolean(_) -> Opt.Some(Typ.TBool()),
Bool(_) -> Opt.Some(Typ.TBool()),
Abs(t_in, e) -> begin
gamma′ = copy(gamma)
gamma′[depth] = t_in
Expand Down Expand Up @@ -243,7 +246,7 @@ function typecheck(ast::Tuple, gamma, depth=0)
return "Unknown var $(var_str(var_depth))"
end
gamma[var_depth]
elseif name == :Boolean
elseif name == :Bool
(:TBool, [])
elseif name == :Abs
t_in, e = children
Expand Down Expand Up @@ -288,25 +291,25 @@ function eq_except_numbers(x::Expr.T, y::Expr.T)
@match x [
Var(_) -> (@match y [
Var(_) -> true,
Boolean(_) -> false,
Bool(_) -> false,
App(_, _) -> false,
Abs(_, _) -> false,
]),
Boolean(_) -> (@match y [
Bool(_) -> (@match y [
Var(_) -> false,
Boolean(_) -> true,
Bool(_) -> true,
App(_, _) -> false,
Abs(_, _) -> false,
]),
App(f1, x1) -> (@match y [
Var(_) -> false,
Boolean(_) -> false,
Bool(_) -> false,
App(f2, x2) -> eq_except_numbers(f1, f2) & eq_except_numbers(x1, x2),
Abs(_, _) -> false,
]),
Abs(ty1, e1) -> (@match y [
Var(_) -> false,
Boolean(_) -> false,
Bool(_) -> false,
App(_, _) -> false,
Abs(ty2, e2) -> eq_except_numbers(ty1, ty2) & eq_except_numbers(e1, e2),
]),
Expand All @@ -316,7 +319,7 @@ end
function has_app(x::Expr.T)
@match x [
Var(_) -> false,
Boolean(_) -> false,
Bool(_) -> false,
App(_, _) -> true,
Abs(_, e) -> has_app(e),
]
Expand All @@ -326,25 +329,25 @@ function eq_structure(x::Expr.T, y::Expr.T)
@match x [
Var(_) -> (@match y [
Var(_) -> true,
Boolean(_) -> false,
Bool(_) -> false,
App(_, _) -> false,
Abs(_, _) -> false,
]),
Boolean(_) -> (@match y [
Bool(_) -> (@match y [
Var(_) -> false,
Boolean(_) -> true,
Bool(_) -> true,
App(_, _) -> false,
Abs(_, _) -> false,
]),
App(f1, x1) -> (@match y [
Var(_) -> false,
Boolean(_) -> false,
Bool(_) -> false,
App(f2, x2) -> eq_structure(f1, f2) & eq_structure(x1, x2),
Abs(_, _) -> false,
]),
Abs(_, e1) -> (@match y [
Var(_) -> false,
Boolean(_) -> false,
Bool(_) -> false,
App(_, _) -> false,
Abs(_, e2) -> eq_structure(e1, e2),
]),
Expand Down Expand Up @@ -393,12 +396,13 @@ end
function sat_num_apps(e::Expr.T, k::DistUInt32)
@match e [
Var(_) -> DistUInt32(0),
Boolean(_) -> DistUInt32(0),
Bool(_) -> DistUInt32(0),
App(f, x) -> min(min(DistUInt32(1), k) + sat_num_apps(f, k) + sat_num_apps(x, k), k),
Abs(_, e′) -> sat_num_apps(e′, k),
]
end

# TODO: why is saturating at 1 different than eq_has_app?
function sat_eq_num_apps(x::Opt.T{T}, y::Opt.T{T}, k::Integer) where T
@match x [
Some(xv) -> (@match y [
Expand Down
4 changes: 2 additions & 2 deletions examples/qc/benchmarks/lib/stlc/generator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function tb_gen_expr(rs::RunState, p, size::Integer, stack_tail, track_return)
@dice_ite if flip_for(rs, "pvar", dependent_dists)
Expr.Var(DistNat(0)) # really, this is arbitrary
else
Expr.Boolean(true) # really, this is arbitrary
Expr.Bool(true) # really, this is arbitrary
end
else
sz′ = size - 1
Expand All @@ -120,7 +120,7 @@ function tb_gen_expr(rs::RunState, p, size::Integer, stack_tail, track_return)
)
Expr.Var(n)
end,
"boolean" => Expr.Boolean(flip_for(rs, "ptrue", dependent_dists)),
"bool" => Expr.Bool(flip_for(rs, "ptrue", dependent_dists)),
"abs" => begin
typ = tb_gen_type(rs, p, p.ty_size, update_stack_tail(p, stack_tail, 10))
e = tb_gen_expr(rs, p, sz′, update_stack_tail(p, stack_tail, 11), track_return)
Expand Down
6 changes: 3 additions & 3 deletions examples/qc/benchmarks/lib/stlc/to_coq_tb.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function typebased_stlc_to_coq(p, adnodes_vals, io)
expected_matchid(s) = s in ["pvar", "ptbool", "freq_var", "freq_boolean", "freq_abs", "freq_app", "ptrue", ["num$(n)" for n in twopowers(p.intwidth)]...]
expected_matchid(s) = s in ["pvar", "ptbool", "freq_var", "freq_bool", "freq_abs", "freq_app", "ptrue", ["num$(n)" for n in twopowers(p.intwidth)]...]

matchid_to_cases = Dict()
for (name, val) in adnodes_vals
Expand Down Expand Up @@ -63,7 +63,7 @@ Fixpoint manual_gen_expr (size : nat) $(join(stack_vars, " ")) : G Expr :=
(1000 - weight_var, bindGen arbitrary (fun p0 : bool => returnGen (Bool p0)))]
| S size' =>
let weight_var := $(mk_match(p.dependents, "freq_var")) in
let weight_boolean := $(mk_match(p.dependents, "freq_boolean")) in
let weight_bool := $(mk_match(p.dependents, "freq_bool")) in
let weight_abs := $(mk_match(p.dependents, "freq_abs")) in
let weight_app := $(mk_match(p.dependents, "freq_app")) in
freq [
Expand All @@ -81,7 +81,7 @@ $(
let p1 := $(join(["n$(n)" for n in twopowers(p.intwidth)], "+")) in
returnGen (Var p1))
$(")" ^ p.intwidth);
(weight_boolean,
(weight_bool,
let weight_true := $(mk_match(p.dependents, "ptrue")) in
freq [ (weight_true, returnGen (Bool true)); (1000 - weight_true, returnGen (Bool false))]
);
Expand Down
Loading

0 comments on commit a5fe094

Please sign in to comment.