Skip to content

Commit

Permalink
Merge Records into ADT
Browse files Browse the repository at this point in the history
  • Loading branch information
Halbaroth committed Oct 22, 2024
1 parent 7ee9629 commit 4ac01de
Show file tree
Hide file tree
Showing 18 changed files with 105 additions and 541 deletions.
2 changes: 1 addition & 1 deletion src/lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
Ac Arith Arrays_rel Bitv Ccx Shostak Relation
Fun_sat Fun_sat_frontend Inequalities Bitv_rel Th_util Adt Adt_rel
Instances IntervalCalculus Intervals_intf Intervals_core Intervals
Ite_rel Matching Matching_types Polynome Records Records_rel
Ite_rel Matching Matching_types Polynome
Satml_frontend_hybrid Satml_frontend Satml Sat_solver Sat_solver_sig
Sig Sig_rel Theory Uf Use Domains Domains_intf Rel_utils Bitlist
; structures
Expand Down
22 changes: 1 addition & 21 deletions src/lib/frontend/models.ml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ module Pp_smtlib_term = struct
asprintf "%a" Ty.pp_smtlib t

let rec print fmt t =
let {Expr.f;xs;ty; _} = Expr.term_view t in
let {Expr.f;xs; _} = Expr.term_view t in
match f, xs with

| Sy.Lit lit, xs ->
Expand Down Expand Up @@ -151,26 +151,6 @@ module Pp_smtlib_term = struct
| Sy.Op Sy.Extract (i, j), [e] ->
fprintf fmt "%a^{%d,%d}" print e i j

| Sy.Op (Sy.Access field), [e] ->
if Options.get_output_smtlib () then
fprintf fmt "(%a %a)" DE.Term.Const.print field print e
else
fprintf fmt "%a.%a" print e DE.Term.Const.print field

| Sy.Op (Sy.Record), _ ->
begin match ty with
| Ty.Trecord { Ty.lbs = lbs; _ } ->
assert (List.length xs = List.length lbs);
fprintf fmt "{";
ignore (List.fold_left2 (fun first (field,_) e ->
fprintf fmt "%s%a = %a" (if first then "" else "; ")
DE.Term.Const.print field print e;
false
) true lbs xs);
fprintf fmt "}";
| _ -> assert false
end

(* TODO: introduce PrefixOp in the future to simplify this ? *)
| Sy.Op op, [e1; e2] when op == Sy.Pow || op == Sy.Integer_round ||
op == Sy.Max_real || op == Sy.Max_int ||
Expand Down
158 changes: 11 additions & 147 deletions src/lib/frontend/translate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -519,85 +519,17 @@ let rec dty_to_ty ?(update = false) ?(is_var = false) dty =
| _ -> unsupported "Type %a" DE.Ty.print dty

and handle_ty_app ?(update = false) ty_c l =
(* Applies the substitutions in [tysubsts] to each encountered type
variable. *)
let rec apply_ty_substs tysubsts ty =
match ty with
| Ty.Tvar { v; _ } ->
Ty.M.find v tysubsts

| Text (tyl, hs) ->
Ty.Text (List.map (apply_ty_substs tysubsts) tyl, hs)

| Tfarray (ti, tv) ->
Tfarray (
apply_ty_substs tysubsts ti,
apply_ty_substs tysubsts tv
)

| Tadt (hs, tyl) ->
Tadt (hs, List.map (apply_ty_substs tysubsts) tyl)

| Trecord ({ args; lbs; _ } as rcrd) ->
Trecord {
rcrd with
args = List.map (apply_ty_substs tysubsts) args;
lbs = List.map (
fun (hs, t) ->
hs, apply_ty_substs tysubsts t
) lbs;
}

| _ -> ty
in
let tyl = List.map (dty_to_ty ~update) l in
(* Recover the initial versions of the types and apply them on the provided
type arguments stored in [tyl]. *)
match Cache.find_ty ty_c with
| Tadt (hs, _) -> Tadt (hs, tyl)

| Trecord { args; _ } as ty ->
let tysubsts =
List.fold_left2 (
fun acc tv ty ->
match tv with
| Ty.Tvar { v; _ } -> Ty.M.add v ty acc
| _ -> assert false
) Ty.M.empty args tyl
in
apply_ty_substs tysubsts ty

| Tadt (hs, _, record) -> Tadt (hs, tyl, record)
| Text (_, s) -> Text (tyl, s)
| _ -> assert false

(** Handles a simple type declaration. *)
let mk_ty_decl (ty_c: DE.ty_cst) =
match DT.definition ty_c with
| Some (
(Adt
{ cases = [| { cstr = { id_ty; _ } as cstr; dstrs; _ } |]; _ } as adt)
) ->
(* Records and adts that only have one case are treated in the same way,
and considered as records. *)
Nest.attach_orders [adt];
let tyvl = Cache.store_ty_vars_ret id_ty in
let lbs =
Array.fold_right (
fun c acc ->
match c with
| Some (DE.{ id_ty; _ } as id) ->
let pty = dty_to_ty id_ty in
(id, pty) :: acc
| _ ->
Fmt.failwith
"Unexpected null label for some field of the record type %a"
DE.Ty.Const.print ty_c

) dstrs []
in
let ty = Ty.trecord ~record_constr:cstr tyvl ty_c lbs in
Cache.store_ty ty_c ty

| Some (Adt { cases; _ } as adt) ->
Nest.attach_orders [adt];
let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in
Expand Down Expand Up @@ -654,29 +586,7 @@ let mk_term_decl ({ id_ty; path; tags; _ } as tcst: DE.term_cst) =
let mk_mr_ty_decls (tdl: DE.ty_cst list) =
let handle_ty_decl (ty: Ty.t) (tdef: DE.Ty.def option) =
match ty, tdef with
| Trecord { args; name; record_constr; _ },
Some (
Adt { cases = [| { dstrs; _ } |]; ty = ty_c; _ }
) ->
let lbs =
Array.fold_right (
fun c acc ->
match c with
| Some (DE.{ id_ty; _ } as id) ->
let pty = dty_to_ty id_ty in
(id, pty) :: acc
| _ ->
Fmt.failwith
"Unexpected null label for some field of the record type %a"
DE.Ty.Const.print ty_c
) dstrs []
in
let ty =
Ty.trecord ~record_constr args name lbs
in
Cache.store_ty ty_c ty

| Tadt (hs, tyl), Some (Adt { cases; ty = ty_c; _ }) ->
| Tadt (hs, tyl, _), Some (Adt { cases; ty = ty_c; _ }) ->
let cs =
Array.fold_right (
fun DE.{ cstr; dstrs; _ } accl ->
Expand All @@ -697,37 +607,15 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) =

| _ -> assert false
in
(* If there are adts in the list of type declarations then records are
converted to adts, because that's how it's done in the legacy typechecker.
But it might be more efficient not to do that. *)
let rev_tdefs, contains_adts =
List.fold_left (
fun (acc, ca) ty_c ->
match DT.definition ty_c with
| Some (Adt { record; cases; _ } as df)
when not record && Array.length cases > 1 ->
df :: acc, true
| Some (Adt _ as df) ->
df :: acc, ca
| Some Abstract | None ->
assert false
) ([], false) tdl
in
let rev_tdefs = List.rev_map (fun td -> Option.get @@ DT.definition td) tdl in
Nest.attach_orders rev_tdefs;
let rev_l =
List.fold_left (
fun acc tdef ->
match tdef with
| DE.Adt { cases; record; ty = ty_c; } as adt ->
| DE.Adt { cases; ty = ty_c; _ } as adt ->
let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in
let record_constr = cases.(0).cstr in
let ty =
if (record || Array.length cases = 1) && not contains_adts
then
Ty.trecord ~record_constr tyvl ty_c []
else
Ty.t_adt ty_c tyvl
in
let ty = Ty.t_adt ty_c tyvl in
Cache.store_ty ty_c ty;
(ty, Some adt) :: acc

Expand Down Expand Up @@ -953,14 +841,8 @@ let rec mk_expr
E.mk_term sy [] ty

| B.Constructor _ ->
begin match dty_to_ty term_ty with
| Trecord _ as ty ->
E.mk_record [] ty
| Tadt _ as ty ->
E.mk_constr tcst [] ty
| ty ->
Fmt.failwith "unexpected type %a@." Ty.print ty
end
let ty = dty_to_ty term_ty in
E.mk_constr tcst [] ty

| _ -> unsupported "Constant term %a" DE.Term.print term
end
Expand Down Expand Up @@ -1001,10 +883,7 @@ let rec mk_expr
let e = aux_mk_expr x in
let sy =
match Cache.find_ty adt with
| Trecord _ ->
Sy.Op (Sy.Access destr)
| Tadt _ ->
Sy.destruct destr
| Tadt _ -> Sy.destruct destr
| _ -> assert false
in
E.mk_term sy [e] ty
Expand Down Expand Up @@ -1035,11 +914,6 @@ let rec mk_expr
| Ty.Tadt _ ->
E.mk_tester cstr (aux_mk_expr x)

| Ty.Trecord _ ->
(* The typechecker allows only testers whose the
two arguments have the same type. Thus, we can always
replace the tester of a record by the true literal. *)
E.vrai
| _ -> assert false
end

Expand Down Expand Up @@ -1306,19 +1180,9 @@ let rec mk_expr

| B.Constructor _, _ ->
let ty = dty_to_ty term_ty in
begin match ty with
| Ty.Tadt _ ->
let l = List.map (fun t -> aux_mk_expr t) args in
E.mk_constr tcst l ty
| Ty.Trecord _ ->
let l = List.map (fun t -> aux_mk_expr t) args in
E.mk_record l ty
| _ ->
Fmt.failwith
"Constructor error: %a does not belong to a record nor an\
algebraic data type"
DE.Term.print app_term
end
let sy = Sy.constr tcst in
let l = List.map (fun t -> aux_mk_expr t) args in
E.mk_term sy l ty

| B.Coercion, [ x ] ->
begin match DT.view (DE.Term.ty x), DT.view term_ty with
Expand Down
7 changes: 2 additions & 5 deletions src/lib/reasoners/adt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ let constr_of_destr ty dest =
~module_name:"Adt" ~function_name:"constr_of_destr"
"ty = %a" Ty.print ty;
match ty with
| Ty.Tadt (s, params) ->
| Ty.Tadt (s, params, _) ->
begin
let cases = Ty.type_body s params in
try
Expand Down Expand Up @@ -182,7 +182,7 @@ module Shostak (X : ALIEN) = struct
in
let xs = List.rev sx in
match f, xs, ty with
| Sy.Op Sy.Constr hs, _, Ty.Tadt (name, params) ->
| Sy.Op Sy.Constr hs, _, Ty.Tadt (name, params, _) ->
let cases = Ty.type_body name params in
let case_hs =
try Ty.assoc_destrs hs cases with Not_found -> assert false
Expand All @@ -203,9 +203,6 @@ module Shostak (X : ALIEN) = struct
else sel_x, ctx (* canonization OK *)
*)

| Sy.Op Sy.Constr _, _, Ty.Trecord _ ->
Fmt.failwith "unexpected record constructor %a@." E.print t

| _ -> assert false

let hash x =
Expand Down
6 changes: 3 additions & 3 deletions src/lib/reasoners/adt_rel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ module Domain = struct

let unknown ty =
match ty with
| Ty.Tadt (name, params) ->
| Ty.Tadt (name, params, _) ->
(* Return the list of all the constructors of the type of [r]. *)
let cases = Ty.type_body name params in
let constrs =
Expand Down Expand Up @@ -462,7 +462,7 @@ let build_constr_eq r c =
match Th.embed r with
| Alien r ->
begin match X.type_info r with
| Ty.Tadt (name, params) as ty ->
| Ty.Tadt (name, params, _) as ty ->
let cases = Ty.type_body name params in
let ds =
try Ty.assoc_destrs c cases with Not_found -> assert false
Expand Down Expand Up @@ -548,7 +548,7 @@ let two = Numbers.Q.from_int 2
(* TODO: we should compute this reverse map in `Ty` and store it there. *)
let constr_of_destr ty d =
match ty with
| Ty.Tadt (name, params) ->
| Ty.Tadt (name, params, _) ->
begin
let cases = Ty.type_body name params in
try
Expand Down
20 changes: 10 additions & 10 deletions src/lib/reasoners/matching.ml
Original file line number Diff line number Diff line change
Expand Up @@ -314,11 +314,11 @@ module Make (X : Arg) : S with type theory = X.t = struct
| _ , [] -> l1
| _ -> List.fold_left (fun acc e -> e :: acc) l2 (List.rev l1)

let xs_modulo_records t { Ty.lbs; _ } =
List.rev
(* let xs_modulo_records t { Ty.lbs; _ } =
List.rev
(List.rev_map
(fun (hs, ty) ->
E.mk_term (Symbols.Op (Symbols.Access hs)) [t] ty) lbs)
E.mk_term (Symbols.Op (Symbols.Access hs)) [t] ty) lbs) *)

module SLE = (* sets of lists of terms *)
Set.Make(struct
Expand Down Expand Up @@ -413,13 +413,13 @@ module Make (X : Arg) : S with type theory = X.t = struct
(fun t l ->
let { E.f = f; xs = xs; ty = ty; _ } = E.term_view t in
if Symbols.compare f_pat f = 0 then xs::l
else
begin
match f_pat, ty with
| Symbols.Op (Symbols.Record), Ty.Trecord record ->
(xs_modulo_records t record) :: l
| _ -> l
end
else l
(* begin
match f_pat, ty with
| Symbols.Op (Symbols.Record), Ty.Trecord record ->
(xs_modulo_records t record) :: l
| _ -> l
end *)
) cl []
in
let cl = filter_classes mconf cl tbox in
Expand Down
Loading

0 comments on commit 4ac01de

Please sign in to comment.