From 85339a361ea56598d47f9de5cb37c7637b561adb Mon Sep 17 00:00:00 2001 From: Christopher Pulte Date: Fri, 11 Aug 2023 10:57:04 +0100 Subject: [PATCH] tidying up index terms --- backend/cn/check.ml | 1 + backend/cn/compile.ml | 6 ++- backend/cn/indexTerms.ml | 86 +++++++++++++++++++--------------------- backend/cn/mucore.ml | 39 +++++++++++------- backend/cn/solver.ml | 3 +- backend/cn/terms.ml | 28 ++++++------- backend/cn/typeErrors.ml | 18 ++++++--- backend/cn/wellTyped.ml | 37 ++++++++--------- 8 files changed, 117 insertions(+), 101 deletions(-) diff --git a/backend/cn/check.ml b/backend/cn/check.ml index 50c8db2e6..2eb136d7b 100644 --- a/backend/cn/check.ml +++ b/backend/cn/check.ml @@ -674,6 +674,7 @@ let rec check_pexpr (pe : 'bty mu_pexpr) ~(expect:BT.t) k (or_ [v1; v2])) end | M_PEapply_fun (fun_id, args) -> + (* TODO: this should be checking the base types *) let expect_args = Mucore.mu_fun_param_types fun_id in let@ () = if List.length expect_args = List.length args then return () else fail (fun _ -> {loc; msg = Number_arguments {has = List.length args; diff --git a/backend/cn/compile.ml b/backend/cn/compile.ml index 53c808956..c9f7d3213 100644 --- a/backend/cn/compile.ml +++ b/backend/cn/compile.ml @@ -591,7 +591,11 @@ module EffectfulTranslation = struct | CNExpr_list es_ -> let@ es = ListM.mapM self es_ in let item_bt = basetype (List.hd es) in - return (IT ((List es), SBT.List item_bt)) + let rec aux = function + | [] -> IT (Nil (SBT.to_basetype item_bt), SBT.List item_bt) + | x::xs -> IT (Cons (x, aux xs), SBT.List item_bt) + in + return (aux es) | CNExpr_memberof (e, xs) -> let@ e = self e in translate_member_access loc env e xs diff --git a/backend/cn/indexTerms.ml b/backend/cn/indexTerms.ml index 2b7c3e0ac..9f74006eb 100644 --- a/backend/cn/indexTerms.ml +++ b/backend/cn/indexTerms.ml @@ -36,9 +36,10 @@ let pp ?(atomic=false) = Terms.pp ~atomic -let rec bound_by_pattern = function - | PSym (s, bt) -> [(s, bt)] - | PWild _bt -> [] +let rec bound_by_pattern (Pat (pat_, bt)) = + match pat_ with + | PSym s -> [(s, bt)] + | PWild -> [] | PConstructor (_s, args) -> List.concat_map (fun (_id, pat) -> bound_by_pattern pat) args @@ -64,9 +65,8 @@ let rec free_vars_ = function | Cast (_cbt, t) -> free_vars t | MemberOffset (_tag, _id) -> SymSet.empty | ArrayOffset (_sct, t) -> free_vars t - | Nil -> SymSet.empty + | Nil _bt -> SymSet.empty | Cons (t1, t2) -> free_vars_list [t1; t2] - | List ts -> free_vars_list ts | Head t -> free_vars t | Tail t -> free_vars t | NthList (i, xs, d) -> free_vars_list [i; xs; d] @@ -124,9 +124,8 @@ let rec fold_ f binders acc = function | Cast (_cbt, t) -> fold f binders acc t | MemberOffset (_tag, _id) -> acc | ArrayOffset (_sct, t) -> fold f binders acc t - | Nil -> acc + | Nil _bt -> acc | Cons (t1, t2) -> fold_list f binders acc [t1; t2] - | List ts -> fold_list f binders acc ts | Head t -> fold f binders acc t | Tail t -> fold f binders acc t | NthList (i, xs, d) -> fold_list f binders acc [i; xs; d] @@ -167,7 +166,7 @@ and fold_list f binders acc xs = let acc' = fold f binders acc x in fold_list f binders acc' xs -let fold_subterms : 'a 'bt. ((Sym.t * 'bt) list -> 'a -> 'bt term -> 'a) -> 'a -> 'bt term -> 'a = +let fold_subterms : 'a. ((Sym.t * 'bt) list -> 'a -> 'bt term -> 'a) -> 'a -> 'bt term -> 'a = fun f acc t -> fold f [] acc t @@ -265,12 +264,10 @@ let rec subst (su : typed subst) (IT (it, bt)) = IT (Good (rt, subst su t), bt) | WrapI (ity, t) -> IT (WrapI (ity, subst su t), bt) - | Nil -> - IT (Nil, bt) + | Nil bt' -> + IT (Nil bt', bt) | Cons (it1,it2) -> IT (Cons (subst su it1, subst su it2), bt) - | List its -> - IT (List (map (subst su) its), bt) | Head it -> IT (Head (subst su it), bt) | Tail it -> @@ -295,33 +292,7 @@ let rec subst (su : typed subst) (IT (it, bt)) = IT (Let ((name, subst su t1), subst su t2), bt) | Match (e, cases) -> let e = subst su e in - let cases = - List.map (fun (pat, body) -> - let rec aux (pat, body) = - match pat with - | PSym (s, bt) -> - let (s, body) = suitably_alpha_rename su.relevant (s, bt) body in - (PSym (s, bt), body) - | PWild bt -> - (PWild bt, body) - | PConstructor (s, args) -> - let args, body = - let ids, pats = List.split args in - let pats, body = aux_list (pats, body) in - (List.combine ids pats, body) - in - (PConstructor (s, args), body) - and aux_list (pats, body) = - match pats with - | [] -> ([], body) - | pat :: pats -> - let (pats, body) = aux_list (pats, body) in - let (pat, body) = aux (pat, body) in - (pat :: pats, body) - in - aux (pat, body) - ) cases - in + let cases = List.map (subst_under_pattern su) cases in IT (Match (e, cases), bt) | Constructor (s, args) -> let args = @@ -341,6 +312,24 @@ and suitably_alpha_rename syms (s, bt) body = else (s, body) +and subst_under_pattern su (Pat (pat_, bt), body) = + match pat_ with + | PSym s -> + let (s, body) = suitably_alpha_rename su.relevant (s, bt) body in + (Pat (PSym s, bt), body) + | PWild -> + (Pat (PWild, bt), body) + | PConstructor (s, args) -> + let body, args = + fold_left_map (fun body (id, pat') -> + let pat', body = subst_under_pattern su (pat', body) in + (body, (id, pat')) + ) body args + in + (Pat (PConstructor (s, args), bt), body) + + + @@ -419,7 +408,8 @@ let rec split_and it = let rec is_const_val = function | IT (Const _, _) -> true - | IT (List xs, _) -> List.for_all is_const_val xs + | IT (Nil _, _) -> true + | IT (Cons (hd, tl), _) -> is_const_val hd && is_const_val tl | _ -> false @@ -638,18 +628,24 @@ let container_of_ (t, tag, member) = (sub_ (pointerToIntegerCast_ t, memberOffset_ (tag, member))) (* list_op *) -let nil_ ~item_bt = IT (Nil, BT.List item_bt) +let nil_ ~item_bt = IT (Nil item_bt, BT.List item_bt) let cons_ (it, it') = IT (Cons (it, it'), bt it') -let list_ ~item_bt its = IT (List its, BT.List item_bt) +let list_ ~item_bt its = + let rec aux = function + | [] -> IT (Nil item_bt, BT.List item_bt) + | x :: xs -> IT (Cons (x, aux xs), BT.List item_bt) + in + aux its + let head_ ~item_bt it = IT (Head it, item_bt) let tail_ it = IT (Tail it, bt it) let nthList_ (n, it, d) = IT (NthList (n, it, d), bt d) let array_to_list_ (arr, i, len) bt = IT (ArrayToList (arr, i, len), bt) -let rec dest_list it = match term it with - | Nil -> Some [] +let rec dest_list it = + match term it with + | Nil _bt -> Some [] | Cons (x, xs) -> Option.map (fun ys -> x :: ys) (dest_list xs) - | List xs -> Some xs (* TODO: maybe include Tail, if we ever actually use it? *) | _ -> None diff --git a/backend/cn/mucore.ml b/backend/cn/mucore.ml index ab5a0dc10..3cefee38a 100644 --- a/backend/cn/mucore.ml +++ b/backend/cn/mucore.ml @@ -307,22 +307,31 @@ let pp_function = function | M_are_compatible -> !^ "are_compatible" let evaluate_fun mu_fun args = - let args_it = List.map IT.term args in match mu_fun with - | M_params_length -> begin match args_it with - | [IT.List xs] -> Some (IT.int_ (List.length xs)) - | _ -> None - end - | M_params_nth -> begin match args_it, List.map IT.is_z args with - | [IT.List xs; _], [_; Some i] -> if Z.lt i (Z.of_int (List.length xs)) - then List.nth_opt xs (Z.to_int i) else None - | _ -> None - end - | M_are_compatible -> begin match List.map IT.is_const args with - | [Some (IT.CType_const ct1, _); Some (IT.CType_const ct2, _)] -> if Sctypes.equal ct1 ct2 - then Some (IT.bool_ true) else None - | _ -> None - end + | M_params_length -> + begin match args with + | [arg] -> + Option.bind (IT.dest_list arg) (fun xs -> + Some (IT.int_ (List.length xs))) + | _ -> None + end + | M_params_nth -> + begin match args with + | [arg1;arg2] -> + Option.bind (IT.dest_list arg1) (fun xs -> + Option.bind (IT.is_z arg2) (fun i -> + if Z.lt i (Z.of_int (List.length xs)) + then List.nth_opt xs (Z.to_int i) else None + )) + | _ -> None + end + | M_are_compatible -> + begin match List.map IT.is_const args with + | [Some (IT.CType_const ct1, _); Some (IT.CType_const ct2, _)] -> + if Sctypes.equal ct1 ct2 + then Some (IT.bool_ true) else None + | _ -> None + end type parse_ast_label_spec = diff --git a/backend/cn/solver.ml b/backend/cn/solver.ml index 094b92962..d2d49bdf0 100644 --- a/backend/cn/solver.ml +++ b/backend/cn/solver.ml @@ -616,7 +616,8 @@ module Translate = struct term (int_ (Option.get (Memory.member_offset decl member))) | ArrayOffset (ct, t) -> term (mul_ (int_ (Memory.size_of_ctype ct), t)) - | IT.List xs -> uninterp_term context (sort bt) it + | Nil _ -> uninterp_term context (sort bt) it + | Cons _ -> uninterp_term context (sort bt) it | NthList (i, xs, d) -> let args = List.map term [i; xs; d] in let nm = bt_suffix_name "nth_list" bt in diff --git a/backend/cn/terms.ml b/backend/cn/terms.ml index baf8c87f6..fe63f5dd5 100644 --- a/backend/cn/terms.ml +++ b/backend/cn/terms.ml @@ -47,10 +47,13 @@ type binop = | Subset [@@deriving eq, ord, show] -type 'bt pattern = - | PSym of Sym.t * 'bt - | PWild of 'bt +type 'bt pattern_ = + | PSym of Sym.t + | PWild | PConstructor of Sym.t * (Id.t * 'bt pattern) list + +and 'bt pattern = + | Pat of 'bt pattern_ * 'bt [@@deriving eq, ord, map] (* over integers and reals *) @@ -76,9 +79,8 @@ type 'bt term_ = | Constructor of Sym.t * (Id.t * 'bt term) list | MemberOffset of Sym.t * Id.t | ArrayOffset of Sctypes.t (*element ct*) * 'bt term (*index*) - | Nil + | Nil of BaseTypes.t | Cons of 'bt term * 'bt term - | List of 'bt term list | Head of 'bt term | Tail of 'bt term | NthList of 'bt term * 'bt term * 'bt term @@ -90,7 +92,7 @@ type 'bt term_ = | MapConst of BaseTypes.t * 'bt term | MapSet of 'bt term * 'bt term * 'bt term | MapGet of 'bt term * 'bt term - | MapDef of (Sym.t * 'bt) * 'bt term + | MapDef of (Sym.t * BaseTypes.t) * 'bt term | Apply of Sym.t * ('bt term) list | Let of (Sym.t * 'bt term) * 'bt term | Match of 'bt term * ('bt pattern * 'bt term) list @@ -106,10 +108,11 @@ let compare = compare_term -let rec pp_pattern = function - | PSym (s, _bt) -> +let rec pp_pattern (Pat (pat_, _bt)) = + match pat_ with + | PSym s -> Sym.pp s - | PWild _bt -> + | PWild -> underscore | PConstructor (c, args) -> Sym.pp c ^^^ @@ -282,12 +285,10 @@ let pp : 'bt 'a. ?atomic:bool -> ?f:('bt term -> Pp.doc -> Pp.doc) -> 'bt term - c_app !^"hd" [aux false o1] | Tail (o1) -> c_app !^"tl" [aux false o1] - | Nil -> - brackets empty + | Nil bt -> + !^"nil" ^^ angles (BaseTypes.pp bt) | Cons (t1,t2) -> mparens (aux true t1 ^^ colon ^^ colon ^^ aux true t2) - | List its -> - mparens (brackets (separate_map (comma ^^ space) (aux false) its)) | NthList (n, xs, d) -> c_app !^"nth_list" [aux false n; aux false xs; aux false d] | ArrayToList (arr, i, len) -> @@ -397,7 +398,6 @@ let rec dtree (IT (it_, bt)) = | (ArrayOffset (ty, t)) -> Dnode (pp_ctor "ArrayOffset", [Dleaf (Sctypes.pp ty); dtree t]) | (Representable (ty, t)) -> Dnode (pp_ctor "Representable", [Dleaf (Sctypes.pp ty); dtree t]) | (Good (ty, t)) -> Dnode (pp_ctor "Good", [Dleaf (Sctypes.pp ty); dtree t]) - | List its -> Dnode (pp_ctor "List", (List.map dtree its)) | (Aligned a) -> Dnode (pp_ctor "Aligned", [dtree a.t; dtree a.align]) | (MapConst (bt, t)) -> Dnode (pp_ctor "MapConst", [dtree t]) | (MapSet (t1, t2, t3)) -> Dnode (pp_ctor "MapSet", [dtree t1; dtree t2; dtree t3]) diff --git a/backend/cn/typeErrors.ml b/backend/cn/typeErrors.ml index 72bf41906..8d7508cbf 100644 --- a/backend/cn/typeErrors.ml +++ b/backend/cn/typeErrors.ml @@ -115,7 +115,6 @@ type message = | NIA : {it: IT.t; hint : string; ctxt : Context.t} -> message | TooBigExponent : {it: IT.t; ctxt : Context.t} -> message | NegativeExponent : {it: IT.t; ctxt : Context.t} -> message - | Polymorphic_it : 'bt IndexTerms.term -> message | Write_value_unrepresentable of {ct: Sctypes.t; location: IT.t; value: IT.t; ctxt : Context.t; model : Solver.model_with_q } | Int_unrepresentable of {value : IT.t; ict : Sctypes.t; ctxt : Context.t; model : Solver.model_with_q} | Unproven_constraint of {constr : LC.t; info : info; ctxt : Context.t; model : Solver.model_with_q; trace : Trace.t} @@ -129,6 +128,10 @@ type message = | Parser of Cerb_frontend.Errors.cparser_cause + | Empty_pattern + | Missing_pattern of Pp.document + | Duplicate_pattern of Loc.t + type type_error = { loc : Locations.t; @@ -313,10 +316,6 @@ let pp_message te = !^("Exponent must be non-negative") in { short; descr = Some descr; state = None; trace = None } - | Polymorphic_it it -> - let short = !^"Type inference failed" in - let descr = !^"Polymorphic index term" ^^^ squotes (IndexTerms.pp it) in - { short; descr = Some descr; state = None; trace = None } | Write_value_unrepresentable {ct; location; value; ctxt; model} -> let short = !^"Write value not representable at type" ^^^ @@ -382,6 +381,15 @@ let pp_message te = | Parser err -> let short = !^(Cerb_frontend.Pp_errors.string_of_cparser_cause err) in { short; descr = None; state = None; trace = None } + | Empty_pattern -> + let short = !^"Empty match expression." in + { short; descr = None; state = None; trace = None } + | Missing_pattern p' -> + let short = !^"Missing pattern" ^^^ squotes p' ^^ dot in + { short; descr = None; state = None; trace = None } + | Duplicate_pattern loc -> + let short = !^"Duplicate pattern (already matched at" ^^^ Loc.pp loc ^^ !^")" in + { short; descr = None; state = None; trace = None } type t = type_error diff --git a/backend/cn/wellTyped.ml b/backend/cn/wellTyped.ml index e3016ffee..c7b1dd355 100644 --- a/backend/cn/wellTyped.ml +++ b/backend/cn/wellTyped.ml @@ -132,7 +132,7 @@ module WIT = struct let eval = Simplify.IndexTerms.eval let rec infer = - fun loc ((IT (it, _)) as it_) -> + fun loc (IT (it, _)) -> match it with | Sym s -> let@ is_a = bound_a s in @@ -511,18 +511,12 @@ module WIT = struct let@ () = WCT.is_ct loc (Integer ity) in let@ t = check loc Integer t in return (IT (WrapI (ity, t), BT.Integer)) - | Nil -> - fail (fun _ -> {loc; msg = Polymorphic_it it_}) + | Nil bt -> + return (IT (Nil bt, BT.List bt)) | Cons (t1,t2) -> let@ t1 = infer loc t1 in let@ t2 = check loc (List (IT.bt t1)) t2 in return (IT (Cons (t1, t2),BT.List (IT.bt t1))) - | List [] -> - fail (fun _ -> {loc; msg = Polymorphic_it it_}) - | List (t :: ts) -> - let@ t = infer loc t in - let@ ts = ListM.mapM (check loc (IT.bt t)) ts in - return (IT (List (t :: ts),BT.List (IT.bt t))) | Head t -> let@ t = infer loc t in let@ bt = ensure_list_type loc t in @@ -586,22 +580,25 @@ module WIT = struct let@ t2 = infer loc t2 in return (IT (Let ((name, t1), t2), IT.bt t2)) end - | Match _ -> failwith "todo" + | Match (e, []) -> + fail (fun _ -> {loc; msg = Empty_pattern}) + | Match (e, case::cases) -> + (* let@ e = infer loc e in *) + (* let case = *) + (* let pat, body = case in *) + (* in *) + failwith "asd" | Constructor _ -> failwith "todo" and check = fun loc ls it -> let@ () = WLS.is_ls loc ls in - match it, ls with - | IT (Nil, _), List bt -> - return (IT (Nil, BT.List bt)) - | _, _ -> - let@ it = infer loc it in - if LS.equal ls (IT.bt it) then - return it - else - let expected = Pp.plain (LS.pp ls) in - fail (illtyped_index_term loc it (IT.bt it) expected) + let@ it = infer loc it in + if LS.equal ls (IT.bt it) then + return it + else + let expected = Pp.plain (LS.pp ls) in + fail (illtyped_index_term loc it (IT.bt it) expected)