From f7bd38d823d98ed35cb962e5349058cb74e3525d Mon Sep 17 00:00:00 2001 From: Christopher Pulte Date: Mon, 7 Aug 2023 22:01:56 +0100 Subject: [PATCH] (untested!) logic for pattern matching and constructors in free-variable and substitution functions --- backend/cn/indexTerms.ml | 78 ++++++++++++++++++++++++++++++++-------- backend/cn/terms.ml | 14 ++++---- 2 files changed, 70 insertions(+), 22 deletions(-) diff --git a/backend/cn/indexTerms.ml b/backend/cn/indexTerms.ml index 624ccab14..2b7c3e0ac 100644 --- a/backend/cn/indexTerms.ml +++ b/backend/cn/indexTerms.ml @@ -37,12 +37,11 @@ let pp ?(atomic=false) = let rec bound_by_pattern = function - | PSym s -> SymSet.singleton s - | PWild -> SymSet.empty + | PSym (s, bt) -> [(s, bt)] + | PWild _bt -> [] | PConstructor (_s, args) -> - List.fold_right SymSet.union - (List.map (fun (_id, pat) -> bound_by_pattern pat) args) - SymSet.empty + List.concat_map (fun (_id, pat) -> bound_by_pattern pat) args + let rec free_vars_ = function | Const _ -> SymSet.empty @@ -83,10 +82,16 @@ let rec free_vars_ = function | Apply (_pred, ts) -> free_vars_list ts | Let ((nm, t1), t2) -> SymSet.union (free_vars t1) (SymSet.remove nm (free_vars t2)) | Match (e, cases) -> - List.fold_right (fun (pattern, body) acc -> - SymSet.union acc (SymSet.diff (free_vars body) (bound_by_pattern pattern)) - ) cases (free_vars e) - | Constructor (_s, args) -> free_vars_list (List.map snd args) + let rec aux acc = function + | [] -> acc + | (pat, body) :: cases -> + let bound = SymSet.of_list (List.map fst (bound_by_pattern pat)) in + let more = SymSet.diff (free_vars body) bound in + aux (SymSet.union more acc) cases + in + aux (free_vars e) cases + | Constructor (_s, args) -> + free_vars_list (List.map snd args) and free_vars (IT (term_, _bt)) = free_vars_ term_ @@ -138,8 +143,18 @@ let rec fold_ f binders acc = function | Let ((nm, IT (t1_, bt)), t2) -> let acc' = fold f binders acc (IT (t1_, bt)) in fold f (binders @ [(nm, bt)]) acc' t2 - | Match _ -> failwith "todo" - | Constructor _ -> failwith "todo" + | Match (e, cases) -> + (* TODO: check this is good *) + let acc' = fold f binders acc e in + let rec aux acc = function + | [] -> acc + | (pat, body) :: cases -> + let acc' = fold f (binders @ bound_by_pattern pat) acc body in + aux acc' cases + in + aux acc' cases + | Constructor (s, args) -> + fold_list f binders acc (List.map snd args) and fold f binders acc (IT (term_, _bt)) = let acc' = fold_ f binders acc term_ in @@ -278,10 +293,43 @@ let rec subst (su : typed subst) (IT (it, bt)) = | Let ((name, t1), t2) -> let name, t2 = suitably_alpha_rename su.relevant (name, basetype t1) t2 in IT (Let ((name, subst su t1), subst su t2), bt) - | Match _ -> - failwith "todo" - | Constructor _ -> - failwith "todo" + | Match (e, cases) -> + let e = subst su e in + let cases = + List.map (fun (pat, body) -> + let rec aux (pat, body) = + match pat with + | PSym (s, bt) -> + let (s, body) = suitably_alpha_rename su.relevant (s, bt) body in + (PSym (s, bt), body) + | PWild bt -> + (PWild bt, body) + | PConstructor (s, args) -> + let args, body = + let ids, pats = List.split args in + let pats, body = aux_list (pats, body) in + (List.combine ids pats, body) + in + (PConstructor (s, args), body) + and aux_list (pats, body) = + match pats with + | [] -> ([], body) + | pat :: pats -> + let (pats, body) = aux_list (pats, body) in + let (pat, body) = aux (pat, body) in + (pat :: pats, body) + in + aux (pat, body) + ) cases + in + IT (Match (e, cases), bt) + | Constructor (s, args) -> + let args = + List.map (fun (id, e) -> + (id, subst su e) + ) args + in + IT (Constructor (s, args), bt) and alpha_rename (s, bt) body = let s' = Sym.fresh_same s in diff --git a/backend/cn/terms.ml b/backend/cn/terms.ml index 2655b36ef..baf8c87f6 100644 --- a/backend/cn/terms.ml +++ b/backend/cn/terms.ml @@ -47,10 +47,10 @@ type binop = | Subset [@@deriving eq, ord, show] -type pattern = - | PSym of Sym.t - | PWild - | PConstructor of Sym.t * (Id.t * pattern) list +type 'bt pattern = + | PSym of Sym.t * 'bt + | PWild of 'bt + | PConstructor of Sym.t * (Id.t * 'bt pattern) list [@@deriving eq, ord, map] (* over integers and reals *) @@ -93,7 +93,7 @@ type 'bt term_ = | MapDef of (Sym.t * 'bt) * 'bt term | Apply of Sym.t * ('bt term) list | Let of (Sym.t * 'bt term) * 'bt term - | Match of 'bt term * (pattern * 'bt term) list + | Match of 'bt term * ('bt pattern * 'bt term) list | Cast of BaseTypes.t * 'bt term and 'bt term = @@ -107,9 +107,9 @@ let compare = compare_term let rec pp_pattern = function - | PSym s -> + | PSym (s, _bt) -> Sym.pp s - | PWild -> + | PWild _bt -> underscore | PConstructor (c, args) -> Sym.pp c ^^^