diff --git a/lib/ast.ml b/lib/ast.ml index 9c5adcb..823a165 100644 --- a/lib/ast.ml +++ b/lib/ast.ml @@ -1,29 +1,8 @@ type loc = (Lexing.position * Lexing.position[@opaque]) [@@deriving show] type ident = loc * string - -type bop = - | Plus_i - | Mult_i - | Div_i - | Less_i - | Greater_i - | Less_eq_i - | Greater_eq_i - | Equal_i - | Minus_i - | Plus_f - | Mult_f - | Div_f - | Less_f - | Greater_f - | Less_eq_f - | Greater_eq_f - | Equal_f - | Minus_f - | And - | Or -[@@deriving show, sexp] -(* Eventually, this will be handled differently, hopefully not as hardcoded *) +type bop = Equal_i | And | Or [@@deriving show, sexp] +(* Equal_i is used in pattern matches, so we keep it even though it is also + defined as a builtin function *) type unop = Uminus_i | Uminus_f [@@deriving show, sexp] @@ -69,11 +48,17 @@ and expr = | Pipe_head of loc * argument * pipeable | Pipe_tail of loc * argument * pipeable | Ctor of loc * ident * expr option - | Match of - loc * decl_attr * expr * (loc * Path.t option * pattern * expr) list + | Match of loc * decl_attr * expr * (clause * expr) list | Local_use of loc * string * expr | Fmt of loc * expr list +and clause = { + cloc : loc; + cpath : Path.t option; + cpat : pattern; + guard : (loc * expr) option; +} + and pipeable = Pip_expr of expr and pattern = diff --git a/lib/codegen/codegen.ml b/lib/codegen/codegen.ml index 098400c..038e2d4 100644 --- a/lib/codegen/codegen.ml +++ b/lib/codegen/codegen.ml @@ -346,56 +346,9 @@ end = struct in let open Llvm in match bop with - | Plus_i -> - { value = bld build_add "add"; typ = Tint; lltyp = int_t; kind = Imm } - | Minus_i -> - { value = bld build_sub "sub"; typ = Tint; lltyp = int_t; kind = Imm } - | Mult_i -> - { value = bld build_mul "mul"; typ = Tint; lltyp = int_t; kind = Imm } - | Div_i -> - { value = bld build_sdiv "div"; typ = Tint; lltyp = int_t; kind = Imm } - | Less_i -> - let value = bld (build_icmp Icmp.Slt) "lt" in - { value; typ = Tbool; lltyp = bool_t; kind = Imm } - | Greater_i -> - let value = bld (build_icmp Icmp.Sgt) "gt" in - { value; typ = Tbool; lltyp = bool_t; kind = Imm } - | Less_eq_i -> - let value = bld (build_icmp Icmp.Sle) "le" in - { value; typ = Tbool; lltyp = bool_t; kind = Imm } - | Greater_eq_i -> - let value = bld (build_icmp Icmp.Sge) "ge" in - { value; typ = Tbool; lltyp = bool_t; kind = Imm } | Equal_i -> let value = bld (build_icmp Icmp.Eq) "eq" in { value; typ = Tbool; lltyp = bool_t; kind = Imm } - | Plus_f -> - let value = bld build_fadd "add" in - { value; typ = Tfloat; lltyp = float_t; kind = Imm } - | Minus_f -> - let value = bld build_fsub "sub" in - { value; typ = Tfloat; lltyp = float_t; kind = Imm } - | Mult_f -> - let value = bld build_fmul "mul" in - { value; typ = Tfloat; lltyp = float_t; kind = Imm } - | Div_f -> - let value = bld build_fdiv "div" in - { value; typ = Tfloat; lltyp = float_t; kind = Imm } - | Less_f -> - let value = bld (build_fcmp Fcmp.Olt) "lt" in - { value; typ = Tbool; lltyp = bool_t; kind = Imm } - | Greater_f -> - let value = bld (build_fcmp Fcmp.Ogt) "gt" in - { value; typ = Tbool; lltyp = bool_t; kind = Imm } - | Less_eq_f -> - let value = bld (build_fcmp Fcmp.Ole) "le" in - { value; typ = Tbool; lltyp = bool_t; kind = Imm } - | Greater_eq_f -> - let value = bld (build_fcmp Fcmp.Oge) "ge" in - { value; typ = Tbool; lltyp = bool_t; kind = Imm } - | Equal_f -> - let value = bld (build_fcmp Fcmp.Oeq) "eq" in - { value; typ = Tbool; lltyp = bool_t; kind = Imm } | And -> let cond1 = gen e1 |> bring_default in diff --git a/lib/codegen/helpers.ml b/lib/codegen/helpers.ml index 9fab1cc..cc9aba7 100644 --- a/lib/codegen/helpers.ml +++ b/lib/codegen/helpers.ml @@ -781,8 +781,16 @@ struct ignore (List.fold_left f start_index params) let var_index var = - let tagptr = Llvm.build_struct_gep var.lltyp var.value 0 "tag" builder in - let value = Llvm.build_load i32_t tagptr "index" builder in + let value = + match var.kind with + | Const_ptr | Ptr -> + let tagptr = + Llvm.build_struct_gep var.lltyp var.value 0 "tag" builder + in + Llvm.build_load i32_t tagptr "index" builder + | Const -> Llvm.(const_extractelement var.value (const_int i32_t 0)) + | Imm -> failwith "Did not expect Imm in var_index" + in { value; typ = Ti32; lltyp = i32_t; kind = Imm } let var_data var typ = diff --git a/lib/parser.mly b/lib/parser.mly index 7c468df..2734839 100644 --- a/lib/parser.mly +++ b/lib/parser.mly @@ -362,7 +362,11 @@ clause_path: |> Option.get } clause: - | path = option(clause_path); pattern = match_pattern; Colon; expr = expr { $loc, path, pattern, expr } + | cpath = option(clause_path); cpat = match_pattern; guard = option(guard); Colon; expr = expr + { { cloc = $loc; cpath; cpat; guard }, expr } + +guard: + | And; expr = expr { $loc, expr } clauses: | clause = clause { [ clause ] } diff --git a/lib/syntax_errors.messages b/lib/syntax_errors.messages index 88d8020..997f764 100644 --- a/lib/syntax_errors.messages +++ b/lib/syntax_errors.messages @@ -1010,6 +1010,14 @@ prog: Match False Lcurly Ident Hbar Ident Rpar +prog: Match False Lcurly Ident And With + + + +prog: Match False Lcurly Ident And False Wildcard + + + prog: Match False Lcurly Ident Colon With diff --git a/lib/typing/patternmatching.ml b/lib/typing/patternmatching.ml index 8224750..103e052 100644 --- a/lib/typing/patternmatching.ml +++ b/lib/typing/patternmatching.ml @@ -44,7 +44,7 @@ module type S = sig Ast.loc -> Ast.decl_attr -> Ast.expr -> - (Ast.loc * Path.t option * Ast.pattern * Ast.expr) list -> + (Ast.clause * Ast.expr) list -> Typed_tree.typed_expr val pattern_id : @@ -181,11 +181,15 @@ let get_variant env loc (_, name) annot = let msg = "Unbound constructor " ^ ctor_name name in raise (Error (loc, msg))) +type guard = (Ast.loc * Ast.expr) option +type attr = Ast.decl_attr = Dmut | Dmove | Dnorm | Dset + type pattern_data = { loc : Ast.loc; ret_expr : Ast.expr; row : int; ret_env : Env.t; + guard : guard; } type typed_pattern = { ptyp : typ; pat : tpat } @@ -193,9 +197,9 @@ and pathed_pattern = int list * typed_pattern and tpat = | Tp_ctor of Ast.loc * ctor_param - | Tp_var of Ast.loc * string * Ast.decl_attr + | Tp_var of Ast.loc * string * attr | Tp_wildcard of Ast.loc - | Tp_record of Ast.loc * record_field list * Ast.decl_attr + | Tp_record of Ast.loc * record_field list * attr | Tp_int of Ast.loc * int | Tp_char of Ast.loc * char @@ -221,6 +225,11 @@ let loc_of_pat = function loc module Tup = struct + (* Here, we decide which part of which clause to check first. We compile the + pattern matches assuming no duplicate clauses. If there are no duplicates, + there is a clear path to choose, see + https://julesjacobs.com/notes/patternmatching/patternmatching.pdf *) + type payload = { path : int list; (* Records need a path instead of just a column. {:a} in 1st column might be [0;0] *) @@ -333,7 +342,7 @@ module Exhaustiveness = struct module Set = Set.Make (String) type wip_kind = New_column | Specialization - type exhaustive = Exh | Wip of wip_kind * typed_pattern list list + type exhaustive = Exh | Wip of wip_kind * (typed_pattern list * guard) list type ctorset = Ctors of ctor list | Inf | Record of field list let ctorset_of_variant loc env typ = @@ -380,8 +389,8 @@ module Exhaustiveness = struct (** Check if ctorset is complete or some ctor is missing. Might also be infinite *) let sig_complete env fstcl patterns = match List.(hd patterns) with - | [] -> Empty - | p :: _ -> ( + | [], _ -> Empty + | p :: _, _ -> ( let typ = p.ptyp and loc = loc_of_pat p.pat in match ctorset_of_variant loc env typ with | Ctors ctors -> @@ -393,12 +402,14 @@ module Exhaustiveness = struct (* Special case if the last case is a wildcard. Here, we might have a complete ctor set before and the wildcard is redundant *) | [] -> last acc - | [ { pat = Tp_wildcard loc; _ } :: _ ] -> lwild loc acc + | [ ({ pat = Tp_wildcard loc; _ } :: _, _) ] -> lwild loc acc | hd :: tl -> fold f lwild last (f acc hd) tl in let f set = function - | { pat = Tp_ctor (_, p); _ } :: _ -> - Set.remove p.ctname set (* TODO wildcard *) + | { pat = Tp_ctor (_, p); _ } :: _, guard -> ( + match guard with + | Some _ -> set + | None -> Set.remove p.ctname set (* TODO wildcard *)) | _ -> set in let last set = @@ -411,7 +422,9 @@ module Exhaustiveness = struct (* The last row is a wildcard, but all ctors are there before. Might be a redundant case (see "redundant_all_cases" test case) *) Maybe_red (loc, ctors) - else f set [ { pat = Tp_wildcard loc; ptyp = typ } ] |> last + else + f set ([ { pat = Tp_wildcard loc; ptyp = typ } ], None) + |> last else fun _ set -> last set in @@ -426,16 +439,16 @@ module Exhaustiveness = struct let patterns = List.filter_map (function - | { pat = Tp_ctor _ | Tp_int _ | Tp_char _; _ } :: _ -> + | { pat = Tp_ctor _ | Tp_int _ | Tp_char _; _ } :: _, _ -> (* Drop row *) rows_empty := false; None - | { pat = Tp_wildcard _ | Tp_var _ | Tp_record _; _ } :: tl -> + | { pat = Tp_wildcard _ | Tp_var _ | Tp_record _; _ } :: tl, guard -> (* Discard head element *) new_col := true; rows_empty := false; - Some tl - | [] -> (* Empty row *) Some []) + Some (tl, guard) + | [], guard -> (* Empty row *) Some ([], guard)) patterns in @@ -455,7 +468,14 @@ module Exhaustiveness = struct let patterns = List.filter_map (function - | { pat = Tp_ctor (_, param); _ } :: tl + | p :: tl, (Some _ as guard) -> + (* Pattern guard *) + rows_empty := false; + new_col := true; + let loc = loc_of_pat p.pat in + (* Each pattern guard is identified by the index of its row TODO *) + Some ({ pat = Tp_int (loc, 0); ptyp = tint } :: p :: tl, guard) + | { pat = Tp_ctor (_, param); _ } :: tl, guard when String.equal param.ctname case -> rows_empty := false; let lst = @@ -465,12 +485,12 @@ module Exhaustiveness = struct tl | lst -> lst @ tl in - Some lst - | { pat = Tp_ctor _; _ } :: _ -> + Some (lst, guard) + | { pat = Tp_ctor _; _ } :: _, _ -> (* Drop row *) rows_empty := false; None - | { pat; ptyp } :: tl -> + | { pat; ptyp } :: tl, guard -> rows_empty := false; let lst = match num_args with @@ -481,8 +501,8 @@ module Exhaustiveness = struct let loc = loc_of_pat pat in to_n_list { pat = Tp_wildcard loc; ptyp } num_args [] @ tl in - Some lst - | [] -> (* Empty row *) Some []) + Some (lst, guard) + | [], guard -> (* Empty row *) Some ([], guard)) patterns in @@ -495,8 +515,8 @@ module Exhaustiveness = struct | Specialization, Specialization -> (Specialization, str) (* We add an extra redundancy check for first column *) - let rec is_exhaustive env fstcl patterns : - (unit, wip_kind * string list) result = + let rec is_exhaustive env fstcl (patterns : (typed_pattern list * guard) list) + : (unit, wip_kind * string list) result = match patterns with | [] -> Error (Specialization, []) | patterns -> ( @@ -546,16 +566,22 @@ module Exhaustiveness = struct let kind, strs = Result.get_error err in let strs = match kind with - | New_column -> - List.map - (fun str -> - Printf.sprintf "%s, %s" (ctor_name ctor) (ctor_name str)) - strs - | Specialization -> - List.map - (fun str -> - Printf.sprintf "%s(%s)" (ctor_name ctor) (ctor_name str)) - strs + | New_column -> ( + match strs with + | [] -> [ ctor_name ctor ] + | strs -> + List.map + (fun str -> + Printf.sprintf "%s, %s" (ctor_name ctor) (ctor_name str)) + strs) + | Specialization -> ( + match strs with + | [] -> [ ctor_name ctor ] + | strs -> + List.map + (fun str -> + Printf.sprintf "%s(%s)" (ctor_name ctor) (ctor_name str)) + strs) in Error (kind, strs) @@ -576,8 +602,8 @@ module Exhaustiveness = struct let wc_field loc ptyp = { pat = Tp_wildcard loc; ptyp } in let f = function - | [] -> failwith "Internal Error: There are so empty records" - | { pat = Tp_record (_, fields, _); _ } :: tl -> + | [], _ -> failwith "Internal Error: There are so empty records" + | { pat = Tp_record (_, fields, _); _ } :: tl, guard -> let fields = List.map (fun f -> @@ -586,12 +612,12 @@ module Exhaustiveness = struct | None -> wc_field f.floc f.iftyp) fields in - fields @ tl - | p :: tl -> + (fields @ tl, guard) + | p :: tl, guard -> let fields = List.map (fun f -> wc_field (loc_of_pat p.pat) f.ftyp) fields in - fields @ tl + (fields @ tl, guard) in is_exhaustive env false (List.map f patterns) end @@ -926,9 +952,9 @@ module Make (C : Core) (R : Recs) = struct let used_rows = ref Row_set.empty in let typed_cases = List.map - (fun (_, local_open, p, ret_expr) -> + (fun ({ Ast.cloc = _; cpath; cpat = p; guard }, ret_expr) -> let env = - match local_open with + match cpath with | Some path -> Env.use_module env loc path | None -> env in @@ -938,7 +964,8 @@ module Make (C : Core) (R : Recs) = struct let loc = loc_of_pat (snd pat).pat in used_rows := Row_set.add Row.{ loc; cnt = !exp_rows } !used_rows; - ([ pat ], { loc; ret_expr; row = !exp_rows; ret_env = env }))) + ( [ pat ], + { loc; ret_expr; row = !exp_rows; ret_env = env; guard } ))) cases in let typed_cases = List.concat typed_cases in @@ -949,7 +976,9 @@ module Make (C : Core) (R : Recs) = struct (* Check for exhaustiveness *) (let patterns = - List.map (fun p -> List.map (fun p -> snd p) (fst p)) typed_cases + List.map + (fun (pl, pd) -> (List.map (fun p -> snd p) pl, pd.guard)) + typed_cases in match Exhaustiveness.is_exhaustive env true patterns with | Ok () -> () @@ -970,13 +999,11 @@ module Make (C : Core) (R : Recs) = struct { cont with expr } and compile_matches env all_loc used_rows cases ret_typ rmut pass = - (* We build the decision tree here. - [match_cases] splits cases into ones that match and ones that don't. - [compile_matches] then generates the tree for the cases. - This boils down to a chain of if-then-else exprs. A heuristic for - choosing the ctor to check first in a case is not needed right now, - since we have neither tuples nor literals in matches, but it will - be part of [compile_matches] eventually *) + (* We build the decision tree here. [match_cases] splits cases into ones + that match and ones that don't. [compile_matches] then generates the tree + for the cases. This boils down to a chain of if-then-else exprs. A + heuristic for choosing the ctor to check first in a case is part of + [Tup.choose_next]. *) (* Magic value, see above *) let expr i = make_var_expr_fn env all_loc i in @@ -998,16 +1025,34 @@ module Make (C : Core) (R : Recs) = struct match cases with | hd :: tl -> ( match Tup.choose_next hd tl with - | Bare d -> + | Bare d -> ( (* Mark row as used *) used_rows := Row_set.remove { cnt = d.row; loc = d.loc } !used_rows; - let ret = convert d.ret_env d.ret_expr in - - unify - (d.loc, "Match expression does not match:") - ret_typ ret.typ env; - ret + (* If there is a pattern guard, there are multiple paths with the + same checks we have just done. The already checked ctors are part + of [tl]. If we don't match the guard, we can continue with the + unguarded cases in [tl] (or with additional guards). *) + match d.guard with + | Some (loc, guard) -> + let then_ = convert d.ret_env d.ret_expr in + unify + (d.loc, "Match expression does not match:") + ret_typ then_.typ env; + + let else_ = + compile_matches env d.loc used_rows tl ret_typ rmut pass + in + let cond = convert d.ret_env guard in + unify (loc, "In pattern guard") tbool cond.typ d.ret_env; + let expr = If (cond, None, then_, else_) in + { typ = ret_typ; expr; attr = no_attr; loc } + | None -> + let ret = convert d.ret_env d.ret_expr in + unify + (d.loc, "Match expression does not match:") + ret_typ ret.typ env; + ret) | Var ({ path; loc; d; patterns; pltyp }, id, dattr) -> (* Bind the variable *) let mut = mut_of_pattr dattr in @@ -1154,9 +1199,20 @@ module Make (C : Core) (R : Recs) = struct ret fields in expr) - | [] -> failwith "Internal Error: Empty match" + | [] -> + (* This can happen if there are pattern guards and the fallback case is + missing. Generate some expression and let it fail later in the + redundancy check. *) + { typ = ret_typ; expr = Const Unit; attr = no_attr; loc = all_loc } and match_cases (i, case) cases if_ else_ = + (* The result of match cases are two pattern lists. The first one will + contain remaining clauses if the pattern was matched. The remaining + clauses will have specialized the current check so that we don't have to + check the same pattern multiple times. E.g. if we match Some(1) in one + clause and Some(2) in the next, the remaining list will now contain only + Int(1) and Int(2). The second pattern list will contain the "else" case + where we haven't matched Some (in this example). *) match cases with | (clauses, d) :: tl -> ( match List.assoc_opt i clauses with @@ -1257,7 +1313,7 @@ module Make (C : Core) (R : Recs) = struct let bind_pattern env loc i p = let typed = type_pattern env ([ i ], p) in let pts = List.map snd typed in - (match Exhaustiveness.is_exhaustive env true [ pts ] with + (match Exhaustiveness.is_exhaustive env true [ (pts, None) ] with | Ok () -> () | Error (_, cases) -> let msg = diff --git a/lib/typing/patternmatching.mli b/lib/typing/patternmatching.mli index 257145c..0571a75 100644 --- a/lib/typing/patternmatching.mli +++ b/lib/typing/patternmatching.mli @@ -39,7 +39,7 @@ module type S = sig Ast.loc -> Ast.decl_attr -> Ast.expr -> - (Ast.loc * Path.t option * Ast.pattern * Ast.expr) list -> + (Ast.clause * Ast.expr) list -> Typed_tree.typed_expr val pattern_id : int -> Ast.pattern -> string * Ast.loc * bool * Ast.decl_attr diff --git a/lib/typing/typing.ml b/lib/typing/typing.ml index 74de26c..1bc5014 100644 --- a/lib/typing/typing.ml +++ b/lib/typing/typing.ml @@ -109,27 +109,7 @@ let check_unused env unused unmutated = in List.iter err unused -let string_of_bop = function - | Ast.Plus_i -> "+" - | Mult_i -> "*" - | Div_i -> "/" - | Less_i -> "<" - | Greater_i -> ">" - | Less_eq_i -> "<=" - | Greater_eq_i -> ">=" - | Equal_i -> "" - | Minus_i -> "-" - | Ast.Plus_f -> "+." - | Mult_f -> "*." - | Div_f -> ">." - | Less_f -> "<." - | Greater_f -> ">." - | Less_eq_f -> "<=." - | Greater_eq_f -> ">=." - | Equal_f -> "=." - | Minus_f -> "-." - | And -> "and" - | Or -> "or" +let string_of_bop = function Ast.Equal_i -> "==" | And -> "and" | Or -> "or" let typeof_annot ?(typedef = false) ?(param = false) env loc annot = let fn_kind = if param then Closure [] else Simple in @@ -1048,12 +1028,7 @@ end = struct let typ, (t1, t2, const) = match bop with - | Ast.Plus_i | Mult_i | Minus_i | Div_i -> (tint, check tint) - | Less_i | Equal_i | Greater_i | Less_eq_i | Greater_eq_i -> - (tbool, check tint) - | Plus_f | Mult_f | Minus_f | Div_f -> (tfloat, check tfloat) - | Less_f | Equal_f | Greater_f | Less_eq_f | Greater_eq_f -> - (tbool, check tfloat) + | Equal_i -> (tbool, check tint) | And | Or -> (tbool, check tbool) in { typ; expr = Bop (bop, t1, t2); attr = { no_attr with const }; loc } diff --git a/test/typing.ml b/test/typing.ml index b1f512d..ab639f9 100644 --- a/test/typing.ml +++ b/test/typing.ml @@ -614,7 +614,7 @@ match Some(10) {Some(1): 1 | Some(10): 10 | Some(_): 0 | None: -1} |} let test_match_int_wildcard_missing () = - test_exn "Pattern match is not exhaustive. Missing cases: " + test_exn "Pattern match is not exhaustive. Missing cases: Some" {|type option['a] = None | Some('a) match Some(10) {Some(1): 1 | Some(10): 10 | None: -1}|} @@ -639,7 +639,51 @@ let test_match_or_missing_var () = let test_match_or_redundant () = test_exn "Pattern match case is redundant" - "match (1, 2) {(a, 1) | (a, 2) | (a, 1): a | _: -1}" + "match (1, 2) { (a, 1) | (a, 2) | (a, 1): a | _: {-1}}" + +let test_match_guard_positive () = + test "unit" + {|type option['a] = None | Some('a) +match Some(0) { + Some(_) and true: { () } + Some(_): { () } + None: { () } +}|} + +let test_match_guard_after () = + test_exn "Pattern match case is redundant" + {|type option['a] = None | Some('a) +match Some(0) { + Some(_): { () } + Some(_) and true: { () } + None: { () } +}|} + +let test_match_guard_missing () = + test_exn "Pattern match is not exhaustive. Missing cases: Some" + {|type option['a] = None | Some('a) +match Some(0) { + Some(_) and true: { () } + None: { () } +}|} + +let test_match_guard_missing_spec () = + test_exn "Pattern match is not exhaustive. Missing cases: Some" + {|type option['a] = None | Some('a) +match Some(0) { + Some(_) and true: { () } + Some(1): { () } + None: { () } +}|} + +let test_match_guard_dont_leak_vars () = + test_exn "No var named a" + {|type option['a] = None | Some('a) +match Some(0) { + Some(a) and true: { ignore(a) } + Some(_): a + None: { println("none") } +}|} let test_multi_record2 () = test "foo[int, bool]" @@ -1644,6 +1688,11 @@ let () = case "or" test_match_or; case "or missing var" test_match_or_missing_var; case "or redundant" test_match_or_redundant; + case "guard positive" test_match_guard_positive; + case "guard after" test_match_guard_after; + case "guard missing" test_match_guard_missing; + case "guard missing spec" test_match_guard_missing_spec; + case "guard no var leak" test_match_guard_dont_leak_vars; ] ); ( "multi params", [