diff --git a/language/sail.ott b/language/sail.ott index fe9ae3fae..8712c60b9 100644 --- a/language/sail.ott +++ b/language/sail.ott @@ -218,16 +218,17 @@ kind :: 'K_' ::= nexp :: 'Nexp_' ::= {{ com numeric expression, of kind Int }} {{ aux _ l }} - | id :: :: id {{ com abbreviation identifier }} - | kid :: :: var {{ com variable }} - | num :: :: constant {{ com constant }} - | id ( nexp1 , ... , nexpn ) :: :: app {{ com app }} - | nexp1 * nexp2 :: :: times {{ com product }} - | nexp1 + nexp2 :: :: sum {{ com sum }} - | nexp1 - nexp2 :: :: minus {{ com subtraction }} - | 2 ^ nexp :: :: exp {{ com exponential }} - | - nexp :: :: neg {{ com unary negation}} - | ( nexp ) :: S :: paren {{ ichlo [[nexp]] }} + | id :: :: id {{ com abbreviation identifier }} + | kid :: :: var {{ com variable }} + | num :: :: constant {{ com constant }} + | id ( nexp1 , ... , nexpn ) :: :: app {{ com app }} + | if n_constraint then nexp1 else nexp2 :: :: if {{ com if-then-else }} + | nexp1 * nexp2 :: :: times {{ com product }} + | nexp1 + nexp2 :: :: sum {{ com sum }} + | nexp1 - nexp2 :: :: minus {{ com subtraction }} + | 2 ^ nexp :: :: exp {{ com exponential }} + | - nexp :: :: neg {{ com unary negation}} + | ( nexp ) :: S :: paren {{ ichlo [[nexp]] }} order :: 'Ord_' ::= {{ com vector order specifications, of kind Order }} diff --git a/src/lib/ast_util.ml b/src/lib/ast_util.ml index 4fb232a21..5c50fda06 100644 --- a/src/lib/ast_util.ml +++ b/src/lib/ast_util.ml @@ -268,41 +268,142 @@ module Id = struct | Id_aux (Operator _, _), Id_aux (Id _, _) -> 1 end +let lex_ord f g x1 x2 y1 y2 = match f x1 x2 with 0 -> g y1 y2 | n -> n + +let rec nexp_compare (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) = + let lex_ord (c1, c2) = if c1 = 0 then c2 else c1 in + match (nexp1, nexp2) with + | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2 + | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2 + | Nexp_constant c1, Nexp_constant c2 -> Big_int.compare c1 c2 + | Nexp_app (op1, args1), Nexp_app (op2, args2) -> + let lex1 = Id.compare op1 op2 in + let lex2 = List.length args1 - List.length args2 in + let lex3 = if lex2 = 0 then List.fold_left2 (fun l n1 n2 -> lex_ord (l, compare n1 n2)) 0 args1 args2 else 0 in + lex_ord (lex1, lex_ord (lex2, lex3)) + | Nexp_times (n1a, n1b), Nexp_times (n2a, n2b) + | Nexp_sum (n1a, n1b), Nexp_sum (n2a, n2b) + | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) -> + lex_ord (compare n1a n2a, compare n1b n2b) + | Nexp_exp n1, Nexp_exp n2 -> compare n1 n2 + | Nexp_neg n1, Nexp_neg n2 -> compare n1 n2 + | Nexp_if (i1, t1, e1), Nexp_if (i2, t2, e2) -> + let lex1 = nc_compare i1 i2 in + let lex2 = nexp_compare t1 t2 in + let lex3 = nexp_compare e1 e2 in + lex_ord (lex1, lex_ord (lex2, lex3)) + | Nexp_constant _, _ -> -1 + | _, Nexp_constant _ -> 1 + | Nexp_id _, _ -> -1 + | _, Nexp_id _ -> 1 + | Nexp_var _, _ -> -1 + | _, Nexp_var _ -> 1 + | Nexp_neg _, _ -> -1 + | _, Nexp_neg _ -> 1 + | Nexp_exp _, _ -> -1 + | _, Nexp_exp _ -> 1 + | Nexp_minus _, _ -> -1 + | _, Nexp_minus _ -> 1 + | Nexp_sum _, _ -> -1 + | _, Nexp_sum _ -> 1 + | Nexp_times _, _ -> -1 + | _, Nexp_times _ -> 1 + | Nexp_if _, _ -> -1 + | _, Nexp_if _ -> 1 + +and nc_compare (NC_aux (nc1, _)) (NC_aux (nc2, _)) = + match (nc1, nc2) with + | NC_id id1, NC_id id2 -> Id.compare id1 id2 + | NC_equal (n1, n2), NC_equal (n3, n4) + | NC_bounded_ge (n1, n2), NC_bounded_ge (n3, n4) + | NC_bounded_gt (n1, n2), NC_bounded_gt (n3, n4) + | NC_bounded_le (n1, n2), NC_bounded_le (n3, n4) + | NC_bounded_lt (n1, n2), NC_bounded_lt (n3, n4) + | NC_not_equal (n1, n2), NC_not_equal (n3, n4) -> + lex_ord nexp_compare nexp_compare n1 n3 n2 n4 + | NC_set (n1, s1), NC_set (n2, s2) -> lex_ord nexp_compare (Util.compare_list Nat_big_num.compare) n1 n2 s1 s2 + | NC_or (nc1, nc2), NC_or (nc3, nc4) | NC_and (nc1, nc2), NC_and (nc3, nc4) -> + lex_ord nc_compare nc_compare nc1 nc3 nc2 nc4 + | NC_app (f1, args1), NC_app (f2, args2) -> lex_ord Id.compare (Util.compare_list typ_arg_compare) f1 f2 args1 args2 + | NC_var v1, NC_var v2 -> Kid.compare v1 v2 + | NC_true, NC_true | NC_false, NC_false -> 0 + | NC_equal _, _ -> -1 + | _, NC_equal _ -> 1 + | NC_bounded_ge _, _ -> -1 + | _, NC_bounded_ge _ -> 1 + | NC_bounded_gt _, _ -> -1 + | _, NC_bounded_gt _ -> 1 + | NC_bounded_le _, _ -> -1 + | _, NC_bounded_le _ -> 1 + | NC_bounded_lt _, _ -> -1 + | _, NC_bounded_lt _ -> 1 + | NC_not_equal _, _ -> -1 + | _, NC_not_equal _ -> 1 + | NC_set _, _ -> -1 + | _, NC_set _ -> 1 + | NC_or _, _ -> -1 + | _, NC_or _ -> 1 + | NC_and _, _ -> -1 + | _, NC_and _ -> 1 + | NC_app _, _ -> -1 + | _, NC_app _ -> 1 + | NC_var _, _ -> -1 + | _, NC_var _ -> 1 + | NC_true, _ -> -1 + | _, NC_true -> 1 + | NC_id _, _ -> -1 + | _, NC_id _ -> 1 + +and typ_compare (Typ_aux (t1, _)) (Typ_aux (t2, _)) = + match (t1, t2) with + | Typ_internal_unknown, Typ_internal_unknown -> 0 + | Typ_id id1, Typ_id id2 -> Id.compare id1 id2 + | Typ_var kid1, Typ_var kid2 -> Kid.compare kid1 kid2 + | Typ_fn (ts1, t2), Typ_fn (ts3, t4) -> ( + match Util.compare_list typ_compare ts1 ts3 with 0 -> typ_compare t2 t4 | n -> n + ) + | Typ_bidir (t1, t2), Typ_bidir (t3, t4) -> ( + match typ_compare t1 t3 with 0 -> typ_compare t2 t4 | n -> n + ) + | Typ_tuple ts1, Typ_tuple ts2 -> Util.compare_list typ_compare ts1 ts2 + | Typ_exist (ks1, nc1, t1), Typ_exist (ks2, nc2, t2) -> ( + match Util.compare_list KOpt.compare ks1 ks2 with + | 0 -> ( + match nc_compare nc1 nc2 with 0 -> typ_compare t1 t2 | n -> n + ) + | n -> n + ) + | Typ_app (id1, ts1), Typ_app (id2, ts2) -> ( + match Id.compare id1 id2 with 0 -> Util.compare_list typ_arg_compare ts1 ts2 | n -> n + ) + | Typ_internal_unknown, _ -> -1 + | _, Typ_internal_unknown -> 1 + | Typ_id _, _ -> -1 + | _, Typ_id _ -> 1 + | Typ_var _, _ -> -1 + | _, Typ_var _ -> 1 + | Typ_fn _, _ -> -1 + | _, Typ_fn _ -> 1 + | Typ_bidir _, _ -> -1 + | _, Typ_bidir _ -> 1 + | Typ_tuple _, _ -> -1 + | _, Typ_tuple _ -> 1 + | Typ_exist _, _ -> -1 + | _, Typ_exist _ -> 1 + +and typ_arg_compare (A_aux (ta1, _)) (A_aux (ta2, _)) = + match (ta1, ta2) with + | A_nexp n1, A_nexp n2 -> nexp_compare n1 n2 + | A_typ t1, A_typ t2 -> typ_compare t1 t2 + | A_bool nc1, A_bool nc2 -> nc_compare nc1 nc2 + | A_nexp _, _ -> -1 + | _, A_nexp _ -> 1 + | A_typ _, _ -> -1 + | _, A_typ _ -> 1 + module Nexp = struct type t = nexp - let rec compare (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) = - let lex_ord (c1, c2) = if c1 = 0 then c2 else c1 in - match (nexp1, nexp2) with - | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2 - | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2 - | Nexp_constant c1, Nexp_constant c2 -> Big_int.compare c1 c2 - | Nexp_app (op1, args1), Nexp_app (op2, args2) -> - let lex1 = Id.compare op1 op2 in - let lex2 = List.length args1 - List.length args2 in - let lex3 = if lex2 = 0 then List.fold_left2 (fun l n1 n2 -> lex_ord (l, compare n1 n2)) 0 args1 args2 else 0 in - lex_ord (lex1, lex_ord (lex2, lex3)) - | Nexp_times (n1a, n1b), Nexp_times (n2a, n2b) - | Nexp_sum (n1a, n1b), Nexp_sum (n2a, n2b) - | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) -> - lex_ord (compare n1a n2a, compare n1b n2b) - | Nexp_exp n1, Nexp_exp n2 -> compare n1 n2 - | Nexp_neg n1, Nexp_neg n2 -> compare n1 n2 - | Nexp_constant _, _ -> -1 - | _, Nexp_constant _ -> 1 - | Nexp_id _, _ -> -1 - | _, Nexp_id _ -> 1 - | Nexp_var _, _ -> -1 - | _, Nexp_var _ -> 1 - | Nexp_neg _, _ -> -1 - | _, Nexp_neg _ -> 1 - | Nexp_exp _, _ -> -1 - | _, Nexp_exp _ -> 1 - | Nexp_minus _, _ -> -1 - | _, Nexp_minus _ -> 1 - | Nexp_sum _, _ -> -1 - | _, Nexp_sum _ -> 1 - | Nexp_times _, _ -> -1 - | _, Nexp_times _ -> 1 + let compare = nexp_compare end module Bindings = Map.Make (Id) @@ -323,6 +424,7 @@ let rec is_nexp_constant (Nexp_aux (nexp, _)) = | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> is_nexp_constant n1 && is_nexp_constant n2 | Nexp_exp n | Nexp_neg n -> is_nexp_constant n | Nexp_app (_, nexps) -> List.for_all is_nexp_constant nexps + | Nexp_if (i, t, e) -> false let int_of_nexp_opt nexp = match nexp with Nexp_aux (Nexp_constant i, _) -> Some i | _ -> None @@ -891,8 +993,10 @@ and string_of_nexp_aux = function | Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ string_of_list ", " string_of_nexp nexps ^ ")" | Nexp_exp n -> "2 ^ " ^ string_of_nexp n | Nexp_neg n -> "- " ^ string_of_nexp n + | Nexp_if (i, t, e) -> + "(if" ^ string_of_n_constraint i ^ " then " ^ string_of_nexp t ^ " else " ^ string_of_nexp e ^ ")" -let rec string_of_typ = function Typ_aux (typ, _) -> string_of_typ_aux typ +and string_of_typ = function Typ_aux (typ, _) -> string_of_typ_aux typ and string_of_typ_aux = function | Typ_internal_unknown -> "" @@ -1214,98 +1318,6 @@ let rec get_scattered_enum_clauses id = function | _ :: defs -> get_scattered_enum_clauses id defs | [] -> [] -let lex_ord f g x1 x2 y1 y2 = match f x1 x2 with 0 -> g y1 y2 | n -> n - -let rec nc_compare (NC_aux (nc1, _)) (NC_aux (nc2, _)) = - match (nc1, nc2) with - | NC_id id1, NC_id id2 -> Id.compare id1 id2 - | NC_equal (n1, n2), NC_equal (n3, n4) - | NC_bounded_ge (n1, n2), NC_bounded_ge (n3, n4) - | NC_bounded_gt (n1, n2), NC_bounded_gt (n3, n4) - | NC_bounded_le (n1, n2), NC_bounded_le (n3, n4) - | NC_bounded_lt (n1, n2), NC_bounded_lt (n3, n4) - | NC_not_equal (n1, n2), NC_not_equal (n3, n4) -> - lex_ord Nexp.compare Nexp.compare n1 n3 n2 n4 - | NC_set (n1, s1), NC_set (n2, s2) -> lex_ord Nexp.compare (Util.compare_list Nat_big_num.compare) n1 n2 s1 s2 - | NC_or (nc1, nc2), NC_or (nc3, nc4) | NC_and (nc1, nc2), NC_and (nc3, nc4) -> - lex_ord nc_compare nc_compare nc1 nc3 nc2 nc4 - | NC_app (f1, args1), NC_app (f2, args2) -> lex_ord Id.compare (Util.compare_list typ_arg_compare) f1 f2 args1 args2 - | NC_var v1, NC_var v2 -> Kid.compare v1 v2 - | NC_true, NC_true | NC_false, NC_false -> 0 - | NC_equal _, _ -> -1 - | _, NC_equal _ -> 1 - | NC_bounded_ge _, _ -> -1 - | _, NC_bounded_ge _ -> 1 - | NC_bounded_gt _, _ -> -1 - | _, NC_bounded_gt _ -> 1 - | NC_bounded_le _, _ -> -1 - | _, NC_bounded_le _ -> 1 - | NC_bounded_lt _, _ -> -1 - | _, NC_bounded_lt _ -> 1 - | NC_not_equal _, _ -> -1 - | _, NC_not_equal _ -> 1 - | NC_set _, _ -> -1 - | _, NC_set _ -> 1 - | NC_or _, _ -> -1 - | _, NC_or _ -> 1 - | NC_and _, _ -> -1 - | _, NC_and _ -> 1 - | NC_app _, _ -> -1 - | _, NC_app _ -> 1 - | NC_var _, _ -> -1 - | _, NC_var _ -> 1 - | NC_true, _ -> -1 - | _, NC_true -> 1 - | NC_id _, _ -> -1 - | _, NC_id _ -> 1 - -and typ_compare (Typ_aux (t1, _)) (Typ_aux (t2, _)) = - match (t1, t2) with - | Typ_internal_unknown, Typ_internal_unknown -> 0 - | Typ_id id1, Typ_id id2 -> Id.compare id1 id2 - | Typ_var kid1, Typ_var kid2 -> Kid.compare kid1 kid2 - | Typ_fn (ts1, t2), Typ_fn (ts3, t4) -> ( - match Util.compare_list typ_compare ts1 ts3 with 0 -> typ_compare t2 t4 | n -> n - ) - | Typ_bidir (t1, t2), Typ_bidir (t3, t4) -> ( - match typ_compare t1 t3 with 0 -> typ_compare t2 t4 | n -> n - ) - | Typ_tuple ts1, Typ_tuple ts2 -> Util.compare_list typ_compare ts1 ts2 - | Typ_exist (ks1, nc1, t1), Typ_exist (ks2, nc2, t2) -> ( - match Util.compare_list KOpt.compare ks1 ks2 with - | 0 -> ( - match nc_compare nc1 nc2 with 0 -> typ_compare t1 t2 | n -> n - ) - | n -> n - ) - | Typ_app (id1, ts1), Typ_app (id2, ts2) -> ( - match Id.compare id1 id2 with 0 -> Util.compare_list typ_arg_compare ts1 ts2 | n -> n - ) - | Typ_internal_unknown, _ -> -1 - | _, Typ_internal_unknown -> 1 - | Typ_id _, _ -> -1 - | _, Typ_id _ -> 1 - | Typ_var _, _ -> -1 - | _, Typ_var _ -> 1 - | Typ_fn _, _ -> -1 - | _, Typ_fn _ -> 1 - | Typ_bidir _, _ -> -1 - | _, Typ_bidir _ -> 1 - | Typ_tuple _, _ -> -1 - | _, Typ_tuple _ -> 1 - | Typ_exist _, _ -> -1 - | _, Typ_exist _ -> 1 - -and typ_arg_compare (A_aux (ta1, _)) (A_aux (ta2, _)) = - match (ta1, ta2) with - | A_nexp n1, A_nexp n2 -> Nexp.compare n1 n2 - | A_typ t1, A_typ t2 -> typ_compare t1 t2 - | A_bool nc1, A_bool nc2 -> nc_compare nc1 nc2 - | A_nexp _, _ -> -1 - | _, A_nexp _ -> 1 - | A_typ _, _ -> -1 - | _, A_typ _ -> 1 - let is_typ_arg_nexp = function A_aux (A_typ _, _) -> true | _ -> false let is_typ_arg_typ = function A_aux (A_typ _, _) -> true | _ -> false @@ -1329,18 +1341,6 @@ end module TypMap = Map.Make (Typ) -let rec nexp_frees (Nexp_aux (nexp, l)) = - match nexp with - | Nexp_id _ -> raise (Reporting.err_typ l "Unimplemented Nexp_id in nexp_frees") - | Nexp_var kid -> KidSet.singleton kid - | Nexp_constant _ -> KidSet.empty - | Nexp_times (n1, n2) -> KidSet.union (nexp_frees n1) (nexp_frees n2) - | Nexp_sum (n1, n2) -> KidSet.union (nexp_frees n1) (nexp_frees n2) - | Nexp_minus (n1, n2) -> KidSet.union (nexp_frees n1) (nexp_frees n2) - | Nexp_exp n -> nexp_frees n - | Nexp_neg n -> nexp_frees n - | Nexp_app (_, nexps) -> List.fold_left KidSet.union KidSet.empty (List.map nexp_frees nexps) - let rec lexp_to_exp (LE_aux (lexp_aux, annot)) = let rewrap e_aux = E_aux (e_aux, annot) in match lexp_aux with @@ -1418,8 +1418,9 @@ let rec kopts_of_nexp (Nexp_aux (nexp, _)) = | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> KOptSet.union (kopts_of_nexp n1) (kopts_of_nexp n2) | Nexp_exp n | Nexp_neg n -> kopts_of_nexp n | Nexp_app (_, nexps) -> List.fold_left KOptSet.union KOptSet.empty (List.map kopts_of_nexp nexps) + | Nexp_if (i, t, e) -> KOptSet.union (kopts_of_constraint i) (KOptSet.union (kopts_of_nexp t) (kopts_of_nexp e)) -let rec kopts_of_constraint (NC_aux (nc, _)) = +and kopts_of_constraint (NC_aux (nc, _)) = match nc with | NC_equal (nexp1, nexp2) | NC_bounded_ge (nexp1, nexp2) @@ -1463,8 +1464,9 @@ let rec tyvars_of_nexp (Nexp_aux (nexp, _)) = | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> KidSet.union (tyvars_of_nexp n1) (tyvars_of_nexp n2) | Nexp_exp n | Nexp_neg n -> tyvars_of_nexp n | Nexp_app (_, nexps) -> List.fold_left KidSet.union KidSet.empty (List.map tyvars_of_nexp nexps) + | Nexp_if (i, t, e) -> KidSet.union (tyvars_of_constraint i) (KidSet.union (tyvars_of_nexp t) (tyvars_of_nexp e)) -let rec tyvars_of_constraint (NC_aux (nc, _)) = +and tyvars_of_constraint (NC_aux (nc, _)) = match nc with | NC_equal (nexp1, nexp2) | NC_bounded_ge (nexp1, nexp2) @@ -1739,10 +1741,11 @@ let rec locate_nexp f (Nexp_aux (nexp_aux, l)) = | Nexp_minus (nexp1, nexp2) -> Nexp_minus (locate_nexp f nexp1, locate_nexp f nexp2) | Nexp_exp nexp -> Nexp_exp (locate_nexp f nexp) | Nexp_neg nexp -> Nexp_neg (locate_nexp f nexp) + | Nexp_if (i, t, e) -> Nexp_if (locate_nc f i, locate_nexp f t, locate_nexp f e) in Nexp_aux (nexp_aux, f l) -let rec locate_nc f (NC_aux (nc_aux, l)) = +and locate_nc f (NC_aux (nc_aux, l)) = let nc_aux = match nc_aux with | NC_id id -> NC_id (locate_id f id) @@ -1937,8 +1940,9 @@ and nexp_subst_aux sv subst = function | Nexp_app (id, nexps) -> Nexp_app (id, List.map (nexp_subst sv subst) nexps) | Nexp_exp nexp -> Nexp_exp (nexp_subst sv subst nexp) | Nexp_neg nexp -> Nexp_neg (nexp_subst sv subst nexp) + | Nexp_if (i, t, e) -> Nexp_if (constraint_subst sv subst i, nexp_subst sv subst t, nexp_subst sv subst e) -let rec constraint_subst sv subst (NC_aux (nc, l)) = NC_aux (constraint_subst_aux l sv subst nc, l) +and constraint_subst sv subst (NC_aux (nc, l)) = NC_aux (constraint_subst_aux l sv subst nc, l) and constraint_subst_aux l sv subst = function | NC_id id -> NC_id id @@ -2002,10 +2006,10 @@ let typquant_subst_kid_aux sv subst = function let typquant_subst_kid sv subst (TypQ_aux (typq, l)) = TypQ_aux (typquant_subst_kid_aux sv subst typq, l) -let subst_kids_nexp substs nexp = - let rec s_snexp substs (Nexp_aux (ne, l) as nexp) = +let subst_kids_nexp, subst_kids_nc, subst_kids_typ, subst_kids_typ_arg = + let rec subst_kids_nexp substs (Nexp_aux (ne, l) as nexp) = let re ne = Nexp_aux (ne, l) in - let s_snexp = s_snexp substs in + let s_snexp = subst_kids_nexp substs in match ne with | Nexp_var v -> ( try KBindings.find v substs with Not_found -> nexp @@ -2017,11 +2021,8 @@ let subst_kids_nexp substs nexp = | Nexp_exp ne -> re (Nexp_exp (s_snexp ne)) | Nexp_neg ne -> re (Nexp_neg (s_snexp ne)) | Nexp_app (id, args) -> re (Nexp_app (id, List.map s_snexp args)) - in - s_snexp substs nexp - -let subst_kids_nc, subst_kids_typ, subst_kids_typ_arg = - let rec subst_kids_nc substs (NC_aux (nc, l) as n_constraint) = + | Nexp_if (i, t, e) -> re (Nexp_if (subst_kids_nc substs i, s_snexp t, s_snexp e)) + and subst_kids_nc substs (NC_aux (nc, l) as n_constraint) = let snexp nexp = subst_kids_nexp substs nexp in let snc nc = subst_kids_nc substs nc in let re nc = NC_aux (nc, l) in @@ -2057,7 +2058,7 @@ let subst_kids_nc, subst_kids_typ, subst_kids_typ_arg = | A_typ t -> A_aux (A_typ (s_styp substs t), l) | A_bool nc -> A_aux (A_bool (subst_kids_nc substs nc), l) in - (subst_kids_nc, s_styp, s_starg) + (subst_kids_nexp, subst_kids_nc, s_styp, s_starg) let before p1 p2 = let open Lexing in diff --git a/src/lib/ast_util.mli b/src/lib/ast_util.mli index 83afe890c..36062c49e 100644 --- a/src/lib/ast_util.mli +++ b/src/lib/ast_util.mli @@ -488,7 +488,6 @@ val prepend_kid : string -> kid -> kid (** {1 Misc functions} *) -val nexp_frees : nexp -> KidSet.t val nexp_identical : nexp -> nexp -> bool val is_nexp_constant : nexp -> bool val int_of_nexp_opt : nexp -> Big_int.num option diff --git a/src/lib/callgraph.ml b/src/lib/callgraph.ml index df85e0d8d..ebe1403a1 100644 --- a/src/lib/callgraph.ml +++ b/src/lib/callgraph.ml @@ -140,6 +140,7 @@ and nexp_ids' (Nexp_aux (aux, _)) = | Nexp_var _ | Nexp_constant _ -> IdSet.empty | Nexp_exp n | Nexp_neg n -> nexp_ids' n | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> IdSet.union (nexp_ids' n1) (nexp_ids' n2) + | Nexp_if (i, t, e) -> IdSet.union (constraint_ids' i) (IdSet.union (nexp_ids' t) (nexp_ids' e)) and typ_ids' (Typ_aux (aux, _)) = match aux with diff --git a/src/lib/chunk_ast.ml b/src/lib/chunk_ast.ml index 24cbb0018..3f91e9fb6 100644 --- a/src/lib/chunk_ast.ml +++ b/src/lib/chunk_ast.ml @@ -595,6 +595,12 @@ let rec chunk_atyp comments chunks (ATyp_aux (aux, l)) = | ATyp_neg arg -> let arg_chunks = rec_chunk_atyp arg in Queue.add (Unary ("-", arg_chunks)) chunks + | ATyp_if (i, t, e) -> + let if_format = { then_brace = false; else_brace = false } in + let i_chunks = rec_chunk_atyp i in + let t_chunks = rec_chunk_atyp t in + let e_chunks = rec_chunk_atyp e in + Queue.add (If_then_else (if_format, i_chunks, t_chunks, e_chunks)) chunks | ATyp_inc -> Queue.add (Atom "inc") chunks | ATyp_dec -> Queue.add (Atom "dec") chunks | ATyp_fn (lhs, rhs, _) -> diff --git a/src/lib/constraint.ml b/src/lib/constraint.ml index ad820605b..0cd9d1efd 100644 --- a/src/lib/constraint.ml +++ b/src/lib/constraint.ml @@ -237,8 +237,8 @@ let to_smt l abstract vars constr = sfun "to_int" [sfun "^" [Atom "2"; exp]] end | Nexp_neg nexp -> sfun "-" [smt_nexp nexp] - in - let rec smt_constraint (NC_aux (aux, _) : n_constraint) : sexpr = + | Nexp_if (i, t, e) -> sfun "ite" [smt_constraint i; smt_nexp t; smt_nexp e] + and smt_constraint (NC_aux (aux, _) : n_constraint) : sexpr = match aux with | NC_id id -> Atom (Util.zencode_string (string_of_id id)) | NC_equal (nexp1, nexp2) -> sfun "=" [smt_nexp nexp1; smt_nexp nexp2] diff --git a/src/lib/initial_check.ml b/src/lib/initial_check.ml index f3068f112..e306877d2 100644 --- a/src/lib/initial_check.ml +++ b/src/lib/initial_check.ml @@ -400,6 +400,7 @@ and to_ast_nexp ctx atyp = | P.ATyp_minus (t1, t2) -> Nexp_aux (Nexp_minus (to_ast_nexp ctx t1, to_ast_nexp ctx t2), l) | P.ATyp_app (id, ts) -> Nexp_aux (Nexp_app (to_ast_id ctx id, List.map (to_ast_nexp ctx) ts), l) | P.ATyp_parens atyp -> to_ast_nexp ctx atyp + | P.ATyp_if (i, t, e) -> Nexp_aux (Nexp_if (to_ast_constraint ctx i, to_ast_nexp ctx t, to_ast_nexp ctx e), l) | _ -> raise (Reporting.err_typ l "Invalid numeric expression in type") and to_ast_bitfield_index_nexp ctx atyp = @@ -1447,6 +1448,7 @@ let initial_ctx = ("implicit", [Some K_int]); ("itself", [Some K_int]); ("not", [Some K_bool]); + ("ite", [Some K_bool; Some K_int; Some K_int]); ]; kinds = KBindings.empty; scattereds = Bindings.empty; diff --git a/src/lib/monomorphise.ml b/src/lib/monomorphise.ml index 662fd9470..ede7ac88a 100644 --- a/src/lib/monomorphise.ml +++ b/src/lib/monomorphise.ml @@ -295,6 +295,7 @@ let rec size_nvars_nexp (Nexp_aux (ne, _)) = | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> size_nvars_nexp n1 @ size_nvars_nexp n2 | Nexp_exp n | Nexp_neg n -> size_nvars_nexp n | Nexp_app (_, args) -> List.concat (List.map size_nvars_nexp args) + | Nexp_if (_, t, e) -> size_nvars_nexp t @ size_nvars_nexp e (* Given a type for a constructor, work out which refinements we ought to produce *) (* TODO collision avoidance *) @@ -2018,7 +2019,7 @@ module Analysis = struct KidSet.fold check kids dempty let deps_of_nexp l kid_deps arg_deps nexp = - let kids = nexp_frees nexp in + let kids = tyvars_of_nexp nexp in deps_of_tyvars l kid_deps arg_deps kids let rec deps_of_nc kid_deps (NC_aux (nc, l)) = @@ -2614,7 +2615,7 @@ module Analysis = struct let qs = match tq with TypQ_no_forall -> [] | TypQ_tq qs -> qs in let eqn_instantiations = Type_check.instantiate_simple_equations qs in let eqn_kid_deps = - KBindings.map (function A_aux (A_nexp nexp, _) -> Some (nexp_frees nexp) | _ -> None) eqn_instantiations + KBindings.map (function A_aux (A_nexp nexp, _) -> Some (tyvars_of_nexp nexp) | _ -> None) eqn_instantiations in let arg i pat = let rec aux (P_aux (p, (l, annot))) = @@ -4385,6 +4386,7 @@ module ToplevelNexpRewrites = struct | Nexp_exp n -> "exp_" ^ mangle_nexp n | Nexp_neg n -> "neg_" ^ mangle_nexp n | Nexp_app (id, args) -> string_of_id id ^ "_" ^ String.concat "_" (List.map mangle_nexp args) + | Nexp_if (_, t, e) -> mangle_nexp t ^ "_or_" ^ mangle_nexp e in (* TODO: I'd like to add a # to distinguish it from user-provided names, but the rewriter currently uses them as a hint that they're not printable in diff --git a/src/lib/parse_ast.ml b/src/lib/parse_ast.ml index ad9b4fb28..deb48d6d0 100644 --- a/src/lib/parse_ast.ml +++ b/src/lib/parse_ast.ml @@ -145,6 +145,7 @@ type atyp_aux = | ATyp_wild | ATyp_tuple of atyp list (* Tuple type *) | ATyp_app of id * atyp list (* type constructor application *) + | ATyp_if of atyp * atyp * atyp | ATyp_exist of kinded_id list * atyp * atyp | ATyp_parens of atyp diff --git a/src/lib/parser.mly b/src/lib/parser.mly index 79f4f7049..84b8439ba 100644 --- a/src/lib/parser.mly +++ b/src/lib/parser.mly @@ -386,6 +386,12 @@ typ_no_caret: $endpos) } typ: + | If_; cond_t = infix_typ; Then; then_t = infix_typ; Else; else_t = infix_typ + { mk_typ (ATyp_if (cond_t, then_t, else_t)) $startpos $endpos } + | t = infix_typ + { t } + +infix_typ: | prefix = prefix_typ_op; x = postfix_typ; xs = list(op = op; prefix = prefix_typ_op; y = postfix_typ { (IT_op op, $startpos(op), $endpos(op)) :: prefix @ y }) diff --git a/src/lib/pretty_print_sail.ml b/src/lib/pretty_print_sail.ml index a2635bf34..d31961985 100644 --- a/src/lib/pretty_print_sail.ml +++ b/src/lib/pretty_print_sail.ml @@ -111,7 +111,7 @@ let rec doc_typ_pat (TP_aux (tpat_aux, _)) = | TP_var kid -> doc_kid kid | TP_app (f, tpats) -> doc_id f ^^ parens (separate_map (comma ^^ space) doc_typ_pat tpats) -let doc_nexp nexp = +let rec doc_nexp nexp = let rec atomic_nexp (Nexp_aux (n_aux, _) as nexp) = match n_aux with | Nexp_constant c -> string (Big_int.to_string c) @@ -125,14 +125,18 @@ let doc_nexp nexp = | _ -> parens (nexp0 nexp) and nexp0 (Nexp_aux (n_aux, _) as nexp) = match n_aux with - | Nexp_sum (n1, Nexp_aux (Nexp_neg n2, _)) | Nexp_minus (n1, n2) -> separate space [nexp0 n1; string "-"; nexp1 n2] - | Nexp_sum (n1, Nexp_aux (Nexp_constant c, _)) when Big_int.less c Big_int.zero -> - separate space [nexp0 n1; string "-"; doc_int (Big_int.abs c)] - | Nexp_sum (n1, n2) -> separate space [nexp0 n1; string "+"; nexp1 n2] + | Nexp_if (i, t, e) -> separate space [string "if"; doc_nc i; string "then"; nexp1 t; string "else"; nexp1 e] | _ -> nexp1 nexp and nexp1 (Nexp_aux (n_aux, _) as nexp) = - match n_aux with Nexp_times (n1, n2) -> separate space [nexp1 n1; string "*"; nexp2 n2] | _ -> nexp2 nexp + match n_aux with + | Nexp_sum (n1, Nexp_aux (Nexp_neg n2, _)) | Nexp_minus (n1, n2) -> separate space [nexp1 n1; string "-"; nexp2 n2] + | Nexp_sum (n1, Nexp_aux (Nexp_constant c, _)) when Big_int.less c Big_int.zero -> + separate space [nexp1 n1; string "-"; doc_int (Big_int.abs c)] + | Nexp_sum (n1, n2) -> separate space [nexp1 n1; string "+"; nexp2 n2] + | _ -> nexp2 nexp and nexp2 (Nexp_aux (n_aux, _) as nexp) = + match n_aux with Nexp_times (n1, n2) -> separate space [nexp2 n1; string "*"; nexp3 n2] | _ -> nexp3 nexp + and nexp3 (Nexp_aux (n_aux, _) as nexp) = match n_aux with | Nexp_neg n -> separate space [string "-"; atomic_nexp n] | Nexp_exp n -> separate space [string "2"; string "^"; atomic_nexp n] @@ -140,7 +144,7 @@ let doc_nexp nexp = in nexp0 nexp -let rec doc_nc nc = +and doc_nc nc = let nc_op op n1 n2 = separate space [doc_nexp n1; string op; doc_nexp n2] in let rec atomic_nc (NC_aux (nc_aux, _) as nc) = match nc_aux with diff --git a/src/lib/specialize.ml b/src/lib/specialize.ml index e7f1d56b4..e8a3bb71e 100644 --- a/src/lib/specialize.ml +++ b/src/lib/specialize.ml @@ -177,9 +177,9 @@ let string_of_instantiation instantiation = | Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_nexp nexps ^ ")" | Nexp_exp n -> "2 ^ " ^ string_of_nexp n | Nexp_neg n -> "- " ^ string_of_nexp n - in - - let rec string_of_typ = function Typ_aux (typ, l) -> string_of_typ_aux typ + | Nexp_if (i, t, e) -> + "(if " ^ string_of_n_constraint i ^ " then " ^ string_of_nexp t ^ " else " ^ string_of_nexp e ^ ")" + and string_of_typ = function Typ_aux (typ, l) -> string_of_typ_aux typ and string_of_typ_aux = function | Typ_id id -> string_of_id id | Typ_var kid -> kid_name (mk_kopt K_type kid) diff --git a/src/lib/type_check.ml b/src/lib/type_check.ml index 25d39aa02..1d8c80de6 100644 --- a/src/lib/type_check.ml +++ b/src/lib/type_check.ml @@ -108,6 +108,7 @@ let rec orig_nexp (Nexp_aux (nexp, l)) = | Nexp_minus (n1, n2) -> rewrap (Nexp_minus (orig_nexp n1, orig_nexp n2)) | Nexp_exp n -> rewrap (Nexp_exp (orig_nexp n)) | Nexp_neg n -> rewrap (Nexp_neg (orig_nexp n)) + | Nexp_if (i, t, e) -> rewrap (Nexp_if (i, orig_nexp t, orig_nexp e)) | _ -> rewrap nexp let is_list (Typ_aux (typ_aux, _)) = @@ -459,32 +460,6 @@ let prove pos env nc = (* 3. Unification *) (**************************************************************************) -let rec nexp_frees ?(exs = KidSet.empty) (Nexp_aux (nexp, _)) = - match nexp with - | Nexp_id _ -> KidSet.empty - | Nexp_var kid -> KidSet.singleton kid - | Nexp_constant _ -> KidSet.empty - | Nexp_times (n1, n2) -> KidSet.union (nexp_frees ~exs n1) (nexp_frees ~exs n2) - | Nexp_sum (n1, n2) -> KidSet.union (nexp_frees ~exs n1) (nexp_frees ~exs n2) - | Nexp_minus (n1, n2) -> KidSet.union (nexp_frees ~exs n1) (nexp_frees ~exs n2) - | Nexp_app (_, ns) -> List.fold_left KidSet.union KidSet.empty (List.map (fun n -> nexp_frees ~exs n) ns) - | Nexp_exp n -> nexp_frees ~exs n - | Nexp_neg n -> nexp_frees ~exs n - -let rec nexp_identical (Nexp_aux (nexp1, _)) (Nexp_aux (nexp2, _)) = - match (nexp1, nexp2) with - | Nexp_id v1, Nexp_id v2 -> Id.compare v1 v2 = 0 - | Nexp_var kid1, Nexp_var kid2 -> Kid.compare kid1 kid2 = 0 - | Nexp_constant c1, Nexp_constant c2 -> Big_int.equal c1 c2 - | Nexp_times (n1a, n1b), Nexp_times (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b - | Nexp_sum (n1a, n1b), Nexp_sum (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b - | Nexp_minus (n1a, n1b), Nexp_minus (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b - | Nexp_exp n1, Nexp_exp n2 -> nexp_identical n1 n2 - | Nexp_neg n1, Nexp_neg n2 -> nexp_identical n1 n2 - | Nexp_app (f1, args1), Nexp_app (f2, args2) when List.length args1 = List.length args2 -> - Id.compare f1 f2 = 0 && List.for_all2 nexp_identical args1 args2 - | _, _ -> false - let rec nc_identical (NC_aux (nc1, _)) (NC_aux (nc2, _)) = match (nc1, nc2) with | NC_equal (n1a, n1b), NC_equal (n2a, n2b) -> nexp_identical n1a n2a && nexp_identical n1b n2b @@ -656,7 +631,7 @@ and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_au ^ string_of_list ", " string_of_kid (KidSet.elements goals) ) ); - if KidSet.is_empty (KidSet.inter (nexp_frees nexp1) goals) then begin + if KidSet.is_empty (KidSet.inter (tyvars_of_nexp nexp1) goals) then begin if prove __POS__ env (NC_aux (NC_equal (nexp1, nexp2), Parse_ast.Unknown)) then KBindings.empty else unify_error l ("Integer expressions " ^ string_of_nexp nexp1 ^ " and " ^ string_of_nexp nexp2 ^ " are not equal") @@ -671,13 +646,13 @@ and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_au | _ -> unify_error l "Unification error" end | Nexp_sum (n1a, n1b) -> - if KidSet.is_empty (nexp_frees n1b) then unify_nexp l env goals n1a (nminus nexp2 n1b) - else if KidSet.is_empty (nexp_frees n1a) then unify_nexp l env goals n1b (nminus nexp2 n1a) + if KidSet.is_empty (tyvars_of_nexp n1b) then unify_nexp l env goals n1a (nminus nexp2 n1b) + else if KidSet.is_empty (tyvars_of_nexp n1a) then unify_nexp l env goals n1b (nminus nexp2 n1a) else begin match nexp_aux2 with | Nexp_sum (n2a, n2b) -> - if KidSet.is_empty (nexp_frees n2a) then unify_nexp l env goals n2b (nminus nexp1 n2a) - else if KidSet.is_empty (nexp_frees n2a) then unify_nexp l env goals n2a (nminus nexp1 n2b) + if KidSet.is_empty (tyvars_of_nexp n2a) then unify_nexp l env goals n2b (nminus nexp1 n2a) + else if KidSet.is_empty (tyvars_of_nexp n2a) then unify_nexp l env goals n2a (nminus nexp1 n2b) else merge_uvars env l (unify_nexp l env goals n1a n2a) (unify_nexp l env goals n1b n2b) | _ -> unify_error l @@ -686,7 +661,7 @@ and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_au ) end | Nexp_minus (n1a, n1b) -> - if KidSet.is_empty (nexp_frees n1b) then unify_nexp l env goals n1a (nsum nexp2 n1b) + if KidSet.is_empty (tyvars_of_nexp n1b) then unify_nexp l env goals n1a (nsum nexp2 n1b) else unify_error l ("Cannot unify minus Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) | Nexp_times (n1a, n1b) -> @@ -704,11 +679,11 @@ and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_au let valid n c = prove __POS__ env (nc_eq (napp (mk_id "mod") [n; c]) (nint 0)) && prove __POS__ env (nc_neq c (nint 0)) in - (*if KidSet.is_empty (nexp_frees n1b) && valid nexp2 n1b then + (*if KidSet.is_empty (tyvars_of_nexp n1b) && valid nexp2 n1b then unify_nexp l env goals n1a (napp (mk_id "div") [nexp2; n1b]) - else if KidSet.is_empty (nexp_frees n1a) && valid nexp2 n1a then + else if KidSet.is_empty (tyvars_of_nexp n1a) && valid nexp2 n1a then unify_nexp l env goals n1b (napp (mk_id "div") [nexp2; n1a]) *) - if KidSet.is_empty (nexp_frees n1a) then begin + if KidSet.is_empty (tyvars_of_nexp n1a) then begin match nexp_aux2 with | Nexp_times (n2a, n2b) when prove __POS__ env (NC_aux (NC_equal (n1a, n2a), Parse_ast.Unknown)) -> unify_nexp l env goals n1b n2b @@ -723,7 +698,7 @@ and unify_nexp l env goals (Nexp_aux (nexp_aux1, _) as nexp1) (Nexp_aux (nexp_au unify_nexp l env goals n1b (napp (mk_id "div") [nexp2; n1a]) | _ -> unify_error l ("Cannot unify Int expression " ^ string_of_nexp nexp1 ^ " with " ^ string_of_nexp nexp2) end - else if KidSet.is_empty (nexp_frees n1b) then begin + else if KidSet.is_empty (tyvars_of_nexp n1b) then begin match nexp_aux2 with | Nexp_times (n2a, n2b) when prove __POS__ env (NC_aux (NC_equal (n1b, n2b), Parse_ast.Unknown)) -> unify_nexp l env goals n1a n2a @@ -873,29 +848,12 @@ let destruct_atom_bool env typ = | Typ_aux (Typ_app (f, [A_aux (A_bool nc, _)]), _) when string_of_id f = "atom_bool" -> Some nc | _ -> None -(* The kid_order function takes a set of Int-kinded kids, and returns - a list of those kids in the order they appear in a type, as well as - a set containing all the kids that did not occur in the type. We - only care about Int-kinded kids because those are the only type - that can appear in an existential. *) - -let rec kid_order_nexp kind_map (Nexp_aux (aux, _)) = - match aux with - | Nexp_var kid when KBindings.mem kid kind_map -> - ([mk_kopt (unaux_kind (KBindings.find kid kind_map)) kid], KBindings.remove kid kind_map) - | Nexp_var _ | Nexp_id _ | Nexp_constant _ -> ([], kind_map) - | Nexp_exp nexp | Nexp_neg nexp -> kid_order_nexp kind_map nexp - | Nexp_times (nexp1, nexp2) | Nexp_sum (nexp1, nexp2) | Nexp_minus (nexp1, nexp2) -> - let ord, kids = kid_order_nexp kind_map nexp1 in - let ord', kids = kid_order_nexp kids nexp2 in - (ord @ ord', kids) - | Nexp_app (_, nexps) -> - List.fold_left - (fun (ord, kids) nexp -> - let ord', kids = kid_order_nexp kids nexp in - (ord @ ord', kids) - ) - ([], kind_map) nexps +(* The kid_order function takes a set of Int-kinded type variables, + and returns a list of those type variables in the order they appear + in a type, as well as a set containing all the kids that did not + occur in the type. We only care about Int-kinded and Bool-kinded + type variables because those are the only type that can appear in + an existential. *) let rec kid_order kind_map (Typ_aux (aux, l) as typ) = match aux with @@ -926,6 +884,29 @@ and kid_order_arg kind_map (A_aux (aux, _)) = | A_nexp nexp -> kid_order_nexp kind_map nexp | A_bool nc -> kid_order_constraint kind_map nc +and kid_order_nexp kind_map (Nexp_aux (aux, _)) = + match aux with + | Nexp_var kid when KBindings.mem kid kind_map -> + ([mk_kopt (unaux_kind (KBindings.find kid kind_map)) kid], KBindings.remove kid kind_map) + | Nexp_var _ | Nexp_id _ | Nexp_constant _ -> ([], kind_map) + | Nexp_exp nexp | Nexp_neg nexp -> kid_order_nexp kind_map nexp + | Nexp_times (nexp1, nexp2) | Nexp_sum (nexp1, nexp2) | Nexp_minus (nexp1, nexp2) -> + let ord, kids = kid_order_nexp kind_map nexp1 in + let ord', kids = kid_order_nexp kids nexp2 in + (ord @ ord', kids) + | Nexp_app (_, nexps) -> + List.fold_left + (fun (ord, kids) nexp -> + let ord', kids = kid_order_nexp kids nexp in + (ord @ ord', kids) + ) + ([], kind_map) nexps + | Nexp_if (i, t, e) -> + let ord, kind_map = kid_order_constraint kind_map i in + let ord', kind_map = kid_order_nexp kind_map t in + let ord'', kind_map = kid_order_nexp kind_map e in + (ord @ ord' @ ord'', kind_map) + and kid_order_constraint kind_map (NC_aux (aux, _)) = match aux with | NC_var kid when KBindings.mem kid kind_map -> @@ -1055,10 +1036,10 @@ let rec subtyp l env typ1 typ2 = let env = add_existential l (List.map (mk_kopt K_int) kids1) nc1 env in let env = add_typ_vars l - (List.map (mk_kopt K_int) (KidSet.elements (KidSet.inter (nexp_frees nexp2) (KidSet.of_list kids2)))) + (List.map (mk_kopt K_int) (KidSet.elements (KidSet.inter (tyvars_of_nexp nexp2) (KidSet.of_list kids2)))) env in - let kids2 = KidSet.elements (KidSet.diff (KidSet.of_list kids2) (nexp_frees nexp2)) in + let kids2 = KidSet.elements (KidSet.diff (KidSet.of_list kids2) (tyvars_of_nexp nexp2)) in if not (kids2 = []) then typ_error l ("Universally quantified constraint generated: " ^ Util.string_of_list ", " string_of_kid kids2) else (); @@ -1220,10 +1201,40 @@ let rec rewrite_sizeof' l env (Nexp_aux (aux, _) as nexp) = let exp1 = rewrite_sizeof' l env nexp1 in let exp2 = rewrite_sizeof' l env nexp2 in mk_exp (E_app (mk_id "emod_int", [exp1; exp2])) + | Nexp_if (i, t, e) -> + let i = rewrite_nc env i in + let t = rewrite_sizeof' l env t in + let e = rewrite_sizeof' l env e in + mk_exp (E_if (i, t, e)) | Nexp_id id when Env.is_abstract_typ id env -> mk_exp (E_sizeof nexp) | Nexp_app _ | Nexp_id _ -> typ_error l ("Cannot re-write sizeof(" ^ string_of_nexp nexp ^ ")") -let rewrite_sizeof l env nexp = +and rewrite_nc env (NC_aux (nc_aux, l)) = mk_exp ~loc:l (rewrite_nc_aux l env nc_aux) + +and rewrite_nc_aux l env = function + | NC_bounded_ge (n1, n2) -> E_app_infix (rewrite_sizeof l env n1, mk_id ">=", rewrite_sizeof l env n2) + | NC_bounded_gt (n1, n2) -> E_app_infix (rewrite_sizeof l env n1, mk_id ">", rewrite_sizeof l env n2) + | NC_bounded_le (n1, n2) -> E_app_infix (rewrite_sizeof l env n1, mk_id "<=", rewrite_sizeof l env n2) + | NC_bounded_lt (n1, n2) -> E_app_infix (rewrite_sizeof l env n1, mk_id "<", rewrite_sizeof l env n2) + | NC_equal (n1, n2) -> E_app_infix (rewrite_sizeof l env n1, mk_id "==", rewrite_sizeof l env n2) + | NC_not_equal (n1, n2) -> E_app_infix (rewrite_sizeof l env n1, mk_id "!=", rewrite_sizeof l env n2) + | NC_and (nc1, nc2) -> E_app_infix (rewrite_nc env nc1, mk_id "&", rewrite_nc env nc2) + | NC_or (nc1, nc2) -> E_app_infix (rewrite_nc env nc1, mk_id "|", rewrite_nc env nc2) + | NC_false -> E_lit (mk_lit L_false) + | NC_true -> E_lit (mk_lit L_true) + | NC_set (_, []) -> E_lit (mk_lit L_false) + | NC_set (nexp, int :: ints) -> + let nexp_eq int = nc_eq nexp (nconstant int) in + unaux_exp (rewrite_nc env (List.fold_left (fun nc int -> nc_or nc (nexp_eq int)) (nexp_eq int) ints)) + | NC_app (f, [A_aux (A_bool nc, _)]) when string_of_id f = "not" -> E_app (mk_id "not_bool", [rewrite_nc env nc]) + | NC_app (f, args) -> unaux_exp (rewrite_nc env (Env.expand_constraint_synonyms env (mk_nc (NC_app (f, args))))) + | NC_var v -> + (* Would be better to translate change E_sizeof to take a kid, then rewrite to E_sizeof *) + E_id (id_of_kid v) + | NC_id id when Env.is_abstract_typ id env -> E_constraint (NC_aux (NC_id id, l)) + | NC_id id -> typ_error l ("Cannot re-write constraint(" ^ string_of_id id ^ ")") + +and rewrite_sizeof l env nexp = try rewrite_sizeof' l env nexp with No_simple_rewrite -> let locals = Env.get_locals env |> Bindings.bindings in @@ -1242,33 +1253,6 @@ let rewrite_sizeof l env nexp = ) end -let rec rewrite_nc env (NC_aux (nc_aux, l)) = mk_exp ~loc:l (rewrite_nc_aux l env nc_aux) - -and rewrite_nc_aux l env = - let mk_exp exp = mk_exp ~loc:l exp in - function - | NC_bounded_ge (n1, n2) -> E_app_infix (mk_exp (E_sizeof n1), mk_id ">=", mk_exp (E_sizeof n2)) - | NC_bounded_gt (n1, n2) -> E_app_infix (mk_exp (E_sizeof n1), mk_id ">", mk_exp (E_sizeof n2)) - | NC_bounded_le (n1, n2) -> E_app_infix (mk_exp (E_sizeof n1), mk_id "<=", mk_exp (E_sizeof n2)) - | NC_bounded_lt (n1, n2) -> E_app_infix (mk_exp (E_sizeof n1), mk_id "<", mk_exp (E_sizeof n2)) - | NC_equal (n1, n2) -> E_app_infix (mk_exp (E_sizeof n1), mk_id "==", mk_exp (E_sizeof n2)) - | NC_not_equal (n1, n2) -> E_app_infix (mk_exp (E_sizeof n1), mk_id "!=", mk_exp (E_sizeof n2)) - | NC_and (nc1, nc2) -> E_app_infix (rewrite_nc env nc1, mk_id "&", rewrite_nc env nc2) - | NC_or (nc1, nc2) -> E_app_infix (rewrite_nc env nc1, mk_id "|", rewrite_nc env nc2) - | NC_false -> E_lit (mk_lit L_false) - | NC_true -> E_lit (mk_lit L_true) - | NC_set (_, []) -> E_lit (mk_lit L_false) - | NC_set (nexp, int :: ints) -> - let nexp_eq int = nc_eq nexp (nconstant int) in - unaux_exp (rewrite_nc env (List.fold_left (fun nc int -> nc_or nc (nexp_eq int)) (nexp_eq int) ints)) - | NC_app (f, [A_aux (A_bool nc, _)]) when string_of_id f = "not" -> E_app (mk_id "not_bool", [rewrite_nc env nc]) - | NC_app (f, args) -> unaux_exp (rewrite_nc env (Env.expand_constraint_synonyms env (mk_nc (NC_app (f, args))))) - | NC_var v -> - (* Would be better to translate change E_sizeof to take a kid, then rewrite to E_sizeof *) - E_id (id_of_kid v) - | NC_id id when Env.is_abstract_typ id env -> E_constraint (NC_aux (NC_id id, l)) - | NC_id id -> typ_error l ("Cannot re-write constraint(" ^ string_of_id id ^ ")") - let can_be_undefined ~at:l env typ = let rec check (Typ_aux (aux, _)) = match aux with @@ -1376,7 +1360,7 @@ let instantiate_simple_equations = let rec find_eqs kid (NC_aux (nc, _)) = match nc with | NC_equal (Nexp_aux (Nexp_var kid', _), nexp) - when Kid.compare kid kid' == 0 && not (KidSet.mem kid (nexp_frees nexp)) -> + when Kid.compare kid kid' == 0 && not (KidSet.mem kid (tyvars_of_nexp nexp)) -> [arg_nexp nexp] | NC_and (nexp1, nexp2) -> find_eqs kid nexp1 @ find_eqs kid nexp2 | _ -> [] diff --git a/src/lib/type_env.ml b/src/lib/type_env.ml index 921d29dbe..631eb7414 100644 --- a/src/lib/type_env.ml +++ b/src/lib/type_env.ml @@ -679,6 +679,10 @@ module Well_formedness = struct wf_nexp exs env nexp2 | Nexp_exp nexp -> wf_nexp exs env nexp (* MAYBE: Could put restrictions on what is allowed here *) | Nexp_neg nexp -> wf_nexp exs env nexp + | Nexp_if (i, t, e) -> + wf_constraint exs env i; + wf_nexp exs env t; + wf_nexp exs env e and wf_constraint exs env (NC_aux (nc_aux, l) as nc) = wf_debug "constraint" string_of_n_constraint nc exs; @@ -798,6 +802,8 @@ and expand_nexp_synonyms env (Nexp_aux (aux, l) as nexp) = | Nexp_neg nexp -> Nexp_aux (Nexp_neg (expand_nexp_synonyms env nexp), l) | Nexp_var kid -> Nexp_aux (Nexp_var kid, l) | Nexp_constant n -> Nexp_aux (Nexp_constant n, l) + | Nexp_if (i, t, e) -> + Nexp_aux (Nexp_if (expand_constraint_synonyms env i, expand_nexp_synonyms env t, expand_nexp_synonyms env e), l) and expand_synonyms env (Typ_aux (typ, l)) = match typ with diff --git a/src/lib/type_internal.ml b/src/lib/type_internal.ml index f089440fe..618a67fe5 100644 --- a/src/lib/type_internal.ml +++ b/src/lib/type_internal.ml @@ -140,6 +140,7 @@ and unloc_nexp_aux = function | Nexp_minus (nexp1, nexp2) -> Nexp_minus (unloc_nexp nexp1, unloc_nexp nexp2) | Nexp_exp nexp -> Nexp_exp (unloc_nexp nexp) | Nexp_neg nexp -> Nexp_neg (unloc_nexp nexp) + | Nexp_if (i, t, e) -> Nexp_if (unloc_n_constraint i, unloc_nexp t, unloc_nexp e) and unloc_nexp = function Nexp_aux (nexp_aux, _) -> Nexp_aux (unloc_nexp_aux nexp_aux, Parse_ast.Unknown) @@ -238,8 +239,10 @@ let rec nexp_power_variables (Nexp_aux (aux, _)) = | Nexp_id _ | Nexp_var _ | Nexp_constant _ -> KidSet.empty | Nexp_app (_, ns) -> List.fold_left KidSet.union KidSet.empty (List.map nexp_power_variables ns) | Nexp_exp n -> tyvars_of_nexp n + | Nexp_if (i, t, e) -> + KidSet.union (constraint_power_variables i) (KidSet.union (nexp_power_variables t) (nexp_power_variables e)) -let constraint_power_variables nc = +and constraint_power_variables nc = List.fold_left KidSet.union KidSet.empty (List.map nexp_power_variables (constraint_nexps nc)) let ex_counter = ref 0 diff --git a/src/sail_coq_backend/pretty_print_coq.ml b/src/sail_coq_backend/pretty_print_coq.ml index b5844193f..de09d1334 100644 --- a/src/sail_coq_backend/pretty_print_coq.ml +++ b/src/sail_coq_backend/pretty_print_coq.ml @@ -418,8 +418,9 @@ let rec count_nexp_vars (Nexp_aux (nexp, _)) = | Nexp_times (n1, n2) | Nexp_sum (n1, n2) | Nexp_minus (n1, n2) -> merge_kid_count (count_nexp_vars n1) (count_nexp_vars n2) | Nexp_exp n | Nexp_neg n -> count_nexp_vars n + | Nexp_if (i, t, e) -> merge_kid_count (count_nc_vars i) (merge_kid_count (count_nexp_vars t) (count_nexp_vars e)) -let rec count_nc_vars (NC_aux (nc, _)) = +and count_nc_vars (NC_aux (nc, _)) = let count_arg (A_aux (arg, _)) = match arg with A_bool nc -> count_nc_vars nc | A_nexp nexp -> count_nexp_vars nexp | A_typ _ -> KBindings.empty in @@ -659,7 +660,7 @@ let rec doc_typ_fns ctx env = let m_pp = doc_nexp ctx ~skip_vars:kid_set m in let tpp, len_pp = string "mword " ^^ m_pp, string "length_mword" in let length_constraint_pp = - if KidSet.is_empty (KidSet.inter kid_set (nexp_frees m)) + if KidSet.is_empty (KidSet.inter kid_set (tyvars_of_nexp m)) then None else Some (separate space [len_pp; doc_var ctx var; string "=?"; doc_nexp ctx m]) in @@ -677,7 +678,7 @@ let rec doc_typ_fns ctx env = let m_pp = doc_nexp ctx ~skip_vars:kid_set m in let tpp, len_pp = string "vec" ^^ space ^^ typ elem_typ ^^ space ^^ m_pp, string "vec_length" in let length_constraint_pp = - if KidSet.is_empty (KidSet.inter kid_set (nexp_frees m)) + if KidSet.is_empty (KidSet.inter kid_set (tyvars_of_nexp m)) then None else Some (separate space [len_pp; doc_var ctx var; string "=?"; doc_nexp ctx m]) in @@ -1857,7 +1858,7 @@ let doc_exp, doc_let = calculating the instantiations. *) let vars_in_env n = let ekids = Env.get_typ_vars env in - let frees = nexp_frees n in + let frees = tyvars_of_nexp n in (not (KidSet.is_empty frees)) && KidSet.for_all (fun kid -> KBindings.mem kid ekids) frees in match (destruct_atom_nexp env typ_of_arg, destruct_atom_nexp env typ_from_fn) with diff --git a/test/typecheck/pass/type_if_then_else.sail b/test/typecheck/pass/type_if_then_else.sail new file mode 100644 index 000000000..6651db717 --- /dev/null +++ b/test/typecheck/pass/type_if_then_else.sail @@ -0,0 +1,9 @@ +default Order dec + +$include + +val test : forall 'n. int('n) -> int(if 'n == 0 then 32 else 64) + +function test(n) = { + if n == 0 then 32 else 64 +} diff --git a/test/typecheck/pass/type_if_then_else_alt.sail b/test/typecheck/pass/type_if_then_else_alt.sail new file mode 100644 index 000000000..dd79a4e41 --- /dev/null +++ b/test/typecheck/pass/type_if_then_else_alt.sail @@ -0,0 +1,9 @@ +default Order dec + +$include + +val test : forall 'n. int('n) -> {'m, ('n == 0 & 'm == 32) | ('n != 0 & 'm == 64). int('m)} + +function test(n) = { + if n == 0 then 32 else 64 +}