Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

special-case enum-like types #189

Open
github-actions bot opened this issue May 14, 2024 · 0 comments
Open

special-case enum-like types #189

github-actions bot opened this issue May 14, 2024 · 0 comments
Labels

Comments

@github-actions
Copy link

# TODO: special-case enum like types

                if param  seen && hasmethod(variants, (Type{param},))
                    push!(seen, param)
                    push!(to_visit, param)
                    push!(tys, param)
                end
            end
        end
    end
    reverse!(tys) # top order

    type_ctor_to_id = Dict()
    for ty in tys
        for (ctor, _) in variants(ty)
            type_ctor_to_id[(ty, ctor)] = length(type_ctor_to_id)
        end
    end

    tys, type_ctor_to_id
end

"""
    generate(rs::RunState, p, track_return)

Build a value of type `p.root_ty` by recursively choosing constructors.

For every type returned by `collect_types(p.root_ty)`, a closure
`(size, stack_tail) -> value` is registered in `type_to_gen`. Each closure:

  * picks a constructor with `frequency_for`, keyed by
    `"<0_?><ty>_variant"`; when `size == 0`, only constructors that do not
    take `ty` itself as a parameter are offered;
  * fills each constructor parameter:
      - `param == ty`: recurse with `size - 1`;
      - another collected type: call its generator at `p.ty_sizes[param]`;
      - `AnyBool`: one `flip_for` coin;
      - `DistUInt32`: the sum of `DistUInt32(n)` over `n in twopowers(p.intwidth)`,
        where each term is kept or zeroed by its own `flip_for` coin;
      - anything else: raises an error.

Nested generators receive `stack_tail` extended (via `update_stack_tail`) with
the id of the `(ty, ctor)` pair being built.

`track_return` is not used in this body.
"""
function generate(rs::RunState, p, track_return)
    tys, type_ctor_to_id = collect_types(p.root_ty)
    # The closures look each other up in this Dict at call time, so a generator
    # may refer to one registered later in the loop.
    type_to_gen = Dict()
    for ty in tys
        type_to_gen[ty] = (size, stack_tail) -> begin
            # Size-0 choices are keyed under separate, "0_"-prefixed names.
            zero_prefix = if size == 0 "0_" else "" end
            dependents = (size, stack_tail)
            # At size 0, constructors that take `ty` itself are filtered out
            # (see the trailing `if` of the comprehension) so recursion stops.
            frequency_for(rs, "$(zero_prefix)$(ty)_variant", dependents, [
                "$(ctor)" => ctor([
                    if param == ty
                        # TODO: special-case enum like types
                        # TODO: if recursing, pass values of sibling *enumlikes*
                        type_to_gen[param](
                            size - 1,
                            update_stack_tail(p, stack_tail, type_ctor_to_id[(ty, ctor)])
                        )
                    elseif param in tys
                        # TODO: special-case enum like types
                        type_to_gen[param](
                            p.ty_sizes[param],
                            update_stack_tail(p, stack_tail, type_ctor_to_id[(ty, ctor)])
                        )
                    elseif param == AnyBool
                        flip_for(rs, "$(zero_prefix)$(ty)_$(ctor)_$(i)", dependents)
                    elseif param == DistUInt32
                        # Each `twopowers` term gets its own coin; the results are summed.
                        sum(
                            @dice_ite if flip_for(rs, "$(zero_prefix)$(ty)_$(ctor)_$(i)_num$(n)", dependents)
                                DistUInt32(n)
                            else
                                DistUInt32(0)
                            end
                            for n in twopowers(p.intwidth)
                        )
                    else
                        error("generate: unsupported parameter type $(param) in constructor $(ctor) of $(ty)")
                    end
                    for (i, param) in enumerate(params)
                ]...)
                for (ctor, params) in variants(ty)
                if size != 0 || all(param != ty for param in params)
            ])
        end
    end
    return type_to_gen[p.root_ty](p.ty_sizes[p.root_ty], empty_stack(p))
end

# Coq type names for the primitive parameter types `DistUInt32`, `DistInt32`
# and `AnyBool`; other types need their own `to_coq` methods.
function to_coq(::Type{DistUInt32})
    return "nat"
end

function to_coq(::Type{DistInt32})
    return "Z"
end

function to_coq(::Type{AnyBool})
    return "bool"
end

"""
    sandwichjoin(pairs; middle, sep)

Nest `middle` inside the `(opener, closer)` pairs and join every piece with `sep`.

The openers are emitted in order, then `middle`, then the closers in reverse
order. For example, `sandwichjoin([("(", ")"), ("[", "]")]; middle="x", sep="")`
returns `"([x])"`.
"""
function sandwichjoin(pairs; middle, sep)
    # Materialize once so both halves come from a single pass over `pairs`.
    collected = collect(pairs)
    openers = [opener for (opener, _) in collected]
    # The innermost closer must come first, so walk the pairs backwards.
    closers = [closer for (_, closer) in Iterators.reverse(collected)]
    return join(vcat(openers, [middle], closers), sep)
end

"""
    derived_to_coq(p, adnodes_vals, io)

Render Coq generator source (built from `bindGen`, `freq` and `returnGen`)
whose weights come from `adnodes_vals`, and return it as a `String`.

1. Each `(name, val)` in `adnodes_vals` is split at `"%%"` into a match id and
   a case. The case's `%`-separated parts are parsed, converted with `tocoq`
   and combined into a Coq tuple pattern. `val` is converted with `thousandths`.
2. For every type from `collect_types(p.root_ty)`, one `Fixpoint gen_<ty>` is
   emitted with a `0` branch and an `S size'` branch. The `0` branch uses only
   constructors that do not take the type itself. The following are all
   weighted by a Coq `match` over `(size, stack vars)` built from step 1:
     - constructor choices;
     - `AnyBool` coins;
     - each `twopowers(p.intwidth)` term of a `DistUInt32`.
   Any case that was not recorded gets weight 500; coins use `1000 - weight`
   for the other side.
3. The generators are placed between the two texts returned by
   `sandwich(workload_of(typeof(p)))`, followed by a `gSized` definition for
   `p.root_ty`.

`io` is accepted but not used by this function.
"""
function derived_to_coq(p, adnodes_vals, io)
    # match id => list of (Coq tuple pattern, weight)
    matchid_to_cases = Dict()
    for (name, val) in adnodes_vals
        matchid, case = split(name, "%%")
        # NOTE(review): eval(Meta.parse(...)) executes text taken from adnode
        # names; this is only safe while those names are produced by this program.
        case = "(" * join([tocoq(eval(Meta.parse(x))) for x in split(case, "%")], ", ") * ")"
        val = thousandths(val)
        push!(get!(matchid_to_cases, matchid, []), (case, val))
    end

    tys, type_ctor_to_id = collect_types(p.root_ty)

    workload = workload_of(typeof(p))
    generators = []

    stack_vars = ["(stack$(i) : nat)" for i in 1:p.stack_size]

    # Coq `match` on (size, stack vars) yielding the recorded weight for each
    # case of `matchid` (sorted), with 500 as the fallback. Throws `KeyError`
    # if nothing was recorded for `matchid`.
    function mk_match(matchid)
        cases = matchid_to_cases[matchid]
        cases = sort(cases)
        "match (size, ($(join(stack_vars, ", ")))) with 
$(join([" " ^ 9 * "| $(name) => $(w)" for (name, w) in cases], "\n"))
         | _ => 500
         end"
    end

    # Stack arguments for a nested generator call: drop the first stack
    # variable and append the constructor id `loc`.
    update_stack_vars(loc) = join(stack_vars[2:end], " ") * " $(loc)"

    # Constructors usable in a branch. The size-0 branch keeps only those that
    # do not take `ty` itself as a parameter.
    variants2(ty, zero_case) = if zero_case
        [
            (ctor, params)
            for (ctor, params) in variants(ty)
            if all(param != ty for param in params)
        ]
    else
        variants(ty)
    end

    # One `Fixpoint gen_<ty>` per collected type. The string literal below is
    # emitted verbatim, so its line breaks and indentation are significant.
    for ty in tys
        push!(generators, "
Fixpoint gen_$(to_coq(ty)) (size : nat) $(join(stack_vars, " ")) : G $(to_coq(ty)) :=
  match size with
$(join([
"  | $(if zero_case 0 else "S size'" end) => 
    $(if length(variants2(ty, zero_case)) > 1 "freq [" else "" end)
    $(join([
"    (* $(ctor) *)

    $(if length(variants2(ty, zero_case)) > 1
        "(
         $(mk_match("$(if zero_case "0_" else "" end)$(ty)_variant_$(ctor)")),
         " else "" end)
            $(sandwichjoin(
                Iterators.flatten([
                if param == ty
                    ["bindGen (gen_$(to_coq(param)) size' $(
                        update_stack_vars(type_ctor_to_id[(ty, ctor)])
                    )) (fun p$(i) : $(to_coq(param)) =>" => ")"]
                elseif param in tys
                    ["bindGen (gen_$(to_coq(param)) $(p.ty_sizes[param]) $(
                        update_stack_vars(type_ctor_to_id[(ty, ctor)])
                    )) (fun p$(i) : $(to_coq(param)) =>" => ")"]
                elseif param == AnyBool
                    ["let weight_true := $(mk_match("$(if zero_case "0_" else "" end)$(ty)_$(ctor)_$(i)")) in
                    bindGen (freq [
                        (weight_true, true);
                        (1000-weight_true, false)
                    ]) (fun p$(i) : $(to_coq(param)) =>" => ")"]
                elseif param == DistUInt32
                    [
                        "let weight_$(n) := $(mk_match("$(if zero_case "0_" else "" end)$(ty)_$(ctor)_$(i)_num$(n)")) in
                        bindGen (freq [
                            (weight_$(n), returnGen $(n));
                            (1000-weight_$(n), returnGen 0)
                        ])
                        (fun n$(n) : nat => $(if j == p.intwidth "
                        let p$(i) := $(join(["n$(n)" for n in twopowers(p.intwidth)], "+ ")) in " else "" end)
                        " => ")"
                        for (j, n) in enumerate(twopowers(p.intwidth))
                    ]
                else
                    error("derived_to_coq: unsupported parameter type $(param) in constructor $(ctor) of $(ty)")
                end
                for (i, param) in enumerate(params)
                ]),
            middle="returnGen ($(ctor) $(join(["p$(i)" for i in 1:length(params)], " ")))",
            sep="\n"))
    $(if length(variants2(ty, zero_case)) > 1 ")" else "" end)
        "
        for (ctor, params) in variants2(ty, zero_case)
    ], ";\n"))
    $(if length(variants2(ty, zero_case)) > 1 "]" else "" end)"
    for zero_case in [true, false]
  ], "\n" ))
    end.")
    end

    # Place the generators between the workload-specific texts and add the
    # top-level `gSized` entry point.
    before, after = sandwich(workload)
    return "$(before)
    $(join(generators, "\n"))

Definition gSized :=
  gen_$(to_coq(p.root_ty)) $(p.ty_sizes[p.root_ty])$(" 0" ^ p.stack_size).

    $(after)"
end
@github-actions github-actions bot added the todo label May 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

No branches or pull requests

0 participants