diff --git a/src/boot/lib/ast.ml b/src/boot/lib/ast.ml index c984f4dba..42bcfb376 100644 --- a/src/boot/lib/ast.ml +++ b/src/boot/lib/ast.ml @@ -378,7 +378,7 @@ and ty = (* Variant type *) | TyVariant of info * ustring list (* Type constructors *) - | TyCon of info * ustring + | TyCon of info * ustring * (bool * ustring list) option (* Type variables *) | TyVar of info * ustring (* Type application *) @@ -542,7 +542,7 @@ let ty_info = function | TyTensor (fi, _) | TyRecord (fi, _) | TyVariant (fi, _) - | TyCon (fi, _) + | TyCon (fi, _, _) | TyVar (fi, _) | TyApp (fi, _, _) -> fi diff --git a/src/boot/lib/mexpr.ml b/src/boot/lib/mexpr.ml index 6dfaa1f1b..12698c283 100644 --- a/src/boot/lib/mexpr.ml +++ b/src/boot/lib/mexpr.ml @@ -229,8 +229,12 @@ let getData = function | PTreeTy (TyVariant (fi, strs)) -> let len = List.length strs in (idTyVariant, [fi], [len], [], [], strs, [], [], [], []) - | PTreeTy (TyCon (fi, x)) -> - (idTyCon, [fi], [], [], [], [x], [], [], [], []) + | PTreeTy (TyCon (fi, x, None)) -> + (idTyCon, [fi], [], [], [], [x], [0], [], [], []) + | PTreeTy (TyCon (fi, x, Some (positive, cons))) -> + let pos = if positive then 1 else 2 in + let len = List.length cons + 1 in + (idTyCon, [fi], [len], [], [], x :: cons, [pos], [], [], []) | PTreeTy (TyVar (fi, x)) -> (idTyVar, [fi], [], [], [], [x], [], [], [], []) | PTreeTy (TyApp (fi, ty1, ty2)) -> diff --git a/src/boot/lib/mlang.ml b/src/boot/lib/mlang.ml index b99cb47ff..d006a247e 100644 --- a/src/boot/lib/mlang.ml +++ b/src/boot/lib/mlang.ml @@ -653,8 +653,8 @@ let rec desugar_ty env = function TyRecord (fi, Record.map (desugar_ty env) bindings) | TyVariant (fi, constrs) -> TyVariant (fi, constrs) - | TyCon (fi, id) -> - TyCon (fi, resolve_alias env id) + | TyCon (fi, id, data) -> + TyCon (fi, resolve_alias env id, data) | TyVar (fi, id) -> TyVar (fi, id) | TyApp (fi, lty, rty) -> @@ -862,7 +862,7 @@ let desugar_top (nss, langs, subs, syns, (stack : (tm -> tm) list)) = function let wrap_con ty_name (CDecl (fi, params, cname, ty)) tm = let app_param ty param = TyApp (fi, ty, TyVar (fi, param)) in let all_param param ty = TyAll (fi, param, ty) in - let con = List.fold_left app_param (TyCon (fi, ty_name)) params in + let con = List.fold_left app_param (TyCon (fi, ty_name, None)) params in TmConDef ( fi , mangle cname diff --git a/src/boot/lib/parser.mly b/src/boot/lib/parser.mly index b6e88f90d..3805073d6 100644 --- a/src/boot/lib/parser.mly +++ b/src/boot/lib/parser.mly @@ -596,17 +596,33 @@ ty_atom: { TyChar $1.i } | TSTRING { TySeq($1.i,TyChar $1.i) } - | type_ident - { TyCon($1.i,$1.v) } + | type_ident ty_data + { TyCon ($1.i, $1.v, $2)} | var_ident { TyVar($1.i,$1.v)} +ty_data: + | + { None } + | AND LBRACKET con_list RBRACKET + { Some (true, $3) } + | NOT LBRACKET con_list RBRACKET + { Some (false, $3) } + ty_list: | ty COMMA ty_list { $1 :: $3 } | ty { [$1] } +con_list: + | type_ident BAR con_list + { $1.v :: $3 } + | type_ident + { [$1.v] } + | + { [] } + label_tys: | label_ident COLON ty {[($1.v, $3)]} diff --git a/src/boot/lib/pprint.ml b/src/boot/lib/pprint.ml index 8090d86c4..f86fabd2e 100644 --- a/src/boot/lib/pprint.ml +++ b/src/boot/lib/pprint.ml @@ -204,7 +204,7 @@ let rec ustring_of_ty = function us "<>" | TyVariant _ -> failwith "Printing of non-empty variant types not yet supported" - | TyCon (_, x) -> + | TyCon (_, x, _) -> pprint_type_str x | TyVar (_, x) -> pprint_var_str x diff --git a/stdlib/mexpr/ast.mc b/stdlib/mexpr/ast.mc index 9a5e91f78..9954581f4 100644 --- a/stdlib/mexpr/ast.mc +++ b/stdlib/mexpr/ast.mc @@ -1414,6 +1414,32 @@ lang ConTypeAst = Ast | TyCon r -> r.info end +lang DataTypeAst = Ast + type DataRec = + {info : Info, + universe : Map Name (Set Name), + positive : Bool, + cons : Set Name} + + syn Type = + | TyData DataRec + + sem tyWithInfo (info : Info) = + | TyData t -> TyData {t with info = info} + + sem infoTy = + | TyData r -> r.info + + sem computeData : DataRec -> Map Name (Set Name) + sem computeData = + | r -> + if r.positive then + mapMap (setIntersect r.cons) r.universe + else + mapMap (lam x. setSubtract x r.cons) r.universe +end + + lang VarTypeAst = Ast syn Type = -- Rigid type variable @@ -1559,5 +1585,5 @@ lang MExprAst = -- Types UnknownTypeAst + BoolTypeAst + IntTypeAst + FloatTypeAst + CharTypeAst + FunTypeAst + SeqTypeAst + RecordTypeAst + VariantTypeAst + ConTypeAst + - VarTypeAst + AppTypeAst + TensorTypeAst + AllTypeAst + AliasTypeAst + DataTypeAst + VarTypeAst + AppTypeAst + TensorTypeAst + AllTypeAst + AliasTypeAst end diff --git a/stdlib/mexpr/boot-parser.mc b/stdlib/mexpr/boot-parser.mc index e67987aa2..7c15a2d6b 100644 --- a/stdlib/mexpr/boot-parser.mc +++ b/stdlib/mexpr/boot-parser.mc @@ -273,9 +273,22 @@ lang BootParser = MExprAst + ConstTransformer constrs = mapEmpty nameCmp} else error "Parsing of non-empty variant types not yet supported" | 209 /-TyCon-/ -> + let data = + let makeData = lam positive. + let cons = setOfSeq nameCmp (map (gname t) (range 1 (glistlen t 0) 1)) in + TyData { info = ginfo t 0, universe = mapEmpty nameCmp, + positive = positive, cons = cons } + in + switch gint t 0 + case 0 then TyUnknown { info = ginfo t 0 } + case 1 then makeData true + case 2 then makeData false + case _ then error "BootParser.matchTerm: Invalid data specifier for TyCon" + end + in TyCon {info = ginfo t 0, ident = gname t 0, - data = TyUnknown { info = ginfo t 0 }} + data = data} | 210 /-TyVar-/ -> TyVar {info = ginfo t 0, ident = gname t 0} diff --git a/stdlib/mexpr/cmp.mc b/stdlib/mexpr/cmp.mc index c69d223d3..480bfddd3 100644 --- a/stdlib/mexpr/cmp.mc +++ b/stdlib/mexpr/cmp.mc @@ -411,6 +411,12 @@ lang ConTypeCmp = Cmp + ConTypeAst else nameDiff end +lang DataTypeCmp = Cmp + DataTypeAst + sem cmpTypeH = + | (TyData l, TyData r) -> + mapCmp setCmp (computeData l) (computeData r) +end + lang VarTypeCmp = Cmp + VarTypeAst sem cmpTypeH = | (TyVar t1, TyVar t2) -> nameCmp t1.ident t2.ident @@ -472,7 +478,7 @@ lang MExprCmp = -- Types UnknownTypeCmp + BoolTypeCmp + IntTypeCmp + FloatTypeCmp + CharTypeCmp + FunTypeCmp + SeqTypeCmp + TensorTypeCmp + RecordTypeCmp + VariantTypeCmp + - ConTypeCmp + VarTypeCmp + AppTypeCmp + AllTypeCmp + AliasTypeCmp + ConTypeCmp + DataTypeCmp + VarTypeCmp + AppTypeCmp + AllTypeCmp + AliasTypeCmp end ----------- diff --git a/stdlib/mexpr/const-types.mc b/stdlib/mexpr/const-types.mc index 360870397..7489ded01 100644 --- a/stdlib/mexpr/const-types.mc +++ b/stdlib/mexpr/const-types.mc @@ -6,7 +6,6 @@ include "ast-builder.mc" let tysym_ = tycon_ "Symbol" let tyref_ = lam a. tyapp_ (tycon_ "Ref") a -let tymap_ = lam k. lam v. tyapp_ (tyapp_ (tycon_ "Map") k) v let tybootparsetree_ = tycon_ "BootParseTree" let tyvarseq_ = lam id. tyseq_ (tyvar_ id) diff --git a/stdlib/mexpr/eq.mc b/stdlib/mexpr/eq.mc index 5dd7b3709..56218c1db 100644 --- a/stdlib/mexpr/eq.mc +++ b/stdlib/mexpr/eq.mc @@ -637,6 +637,15 @@ lang ConTypeEq = Eq + ConTypeAst else None () end +lang DataTypeEq = Eq + DataTypeAst + sem eqTypeH (typeEnv : EqTypeEnv) (free : EqTypeFreeEnv) (lhs : Type) = + | rhs & TyData r -> + match unwrapType lhs with TyData l then + if mapEq setEq (computeData l) (computeData r) then Some free + else None () + else None () +end + lang VarTypeEq = Eq + VarTypeAst sem eqTypeH (typeEnv : EqTypeEnv) (free : EqTypeFreeEnv) (lhs : Type) = | TyVar r -> @@ -714,8 +723,8 @@ lang MExprEq = -- Types + UnknownTypeEq + BoolTypeEq + IntTypeEq + FloatTypeEq + CharTypeEq + - FunTypeEq + SeqTypeEq + RecordTypeEq + VariantTypeEq + ConTypeEq + VarTypeEq + - AllTypeEq + AppTypeEq + TensorTypeEq + AliasTypeEq + FunTypeEq + SeqTypeEq + RecordTypeEq + VariantTypeEq + ConTypeEq + DataTypeEq + + VarTypeEq + AllTypeEq + AppTypeEq + TensorTypeEq + AliasTypeEq end ----------- diff --git a/stdlib/mexpr/pprint.mc b/stdlib/mexpr/pprint.mc index 2fd4b6ebb..a87bf5092 100644 --- a/stdlib/mexpr/pprint.mc +++ b/stdlib/mexpr/pprint.mc @@ -1144,13 +1144,24 @@ lang VariantTypePrettyPrint = VariantTypeAst -- still use TyVariant in the AST and might get compilation errors for it. end -lang ConTypePrettyPrint = IdentifierPrettyPrint + ConTypeAst + UnknownTypeAst +lang ConTypePrettyPrint = IdentifierPrettyPrint + ConTypeAst + UnknownTypeAst + DataTypeAst sem getTypeStringCode (indent : Int) (env: PprintEnv) = | TyCon t -> match pprintTypeName env t.ident with (env, idstr) in - match t.data with TyUnknown {} then (env, idstr) else - match getTypeStringCode indent env t.data with (env, str) in - (env, join [str, ".", idstr]) + let d = unwrapType t.data in + match d with TyUnknown _ then (env, idstr) else + match getTypeStringCode indent env t.data with (env, datastr) in + match d with TyData _ then (env, concat idstr datastr) else + (env, join [idstr, "&", datastr]) +end + +lang DataTypePrettyPrint = IdentifierPrettyPrint + DataTypeAst + sem getTypeStringCode (indent : Int) (env: PprintEnv) = + | TyData t -> + let consstr = strJoin "|" (map nameGetStr (setToSeq t.cons)) in + let datastr = + join [if t.positive then "&" else "!", "{", consstr, "}"] + in (env, datastr) end lang VarTypePrettyPrint = IdentifierPrettyPrint + VarTypeAst @@ -1159,22 +1170,19 @@ lang VarTypePrettyPrint = IdentifierPrettyPrint + VarTypeAst pprintVarName env t.ident end -lang KindPrettyPrint = PrettyPrint + RecordTypeAst + KindAst +lang KindPrettyPrint = PrettyPrint + RecordTypeAst + DataTypeAst + KindAst sem getKindStringCode (indent : Int) (env : PprintEnv) = | Record r -> - let recty = TyRecord {info = NoInfo (), fields = r.fields} in - match getTypeStringCode indent env recty with (env, recstr) in - (env, join [init recstr, " ... ", [last recstr]]) + let tyrec = TyRecord {info = NoInfo (), fields = r.fields} in + getTypeStringCode indent env tyrec | Data r -> - let tstr = - mapFoldWithKey (lam strs. lam t. lam ks. - snoc strs (join [ nameGetStr t, "{" - , strJoin ", " (map nameGetStr (setToSeq ks)) - , "}" ]) - ) "" r.types in - (env, join ["<", strJoin ", " tstr, ">"]) - | Poly () -> (env, "*") - | Mono () -> (env, "o") + let consstr = + mapFoldWithKey (lam strs. lam. lam ks. + snoc strs (strJoin "|" (map nameGetStr (setToSeq ks)))) + [] r.types in + (env, join ["{", strJoin "|" consstr, "}"]) + | Poly () -> (env, "Poly") + | Mono () -> (env, "Mono") end lang AllTypePrettyPrint = IdentifierPrettyPrint + AllTypeAst + KindPrettyPrint @@ -1187,7 +1195,7 @@ lang AllTypePrettyPrint = IdentifierPrettyPrint + AllTypeAst + KindPrettyPrint match match t.kind with Mono () | Poly () then (env, "") else match getKindStringCode indent env t.kind with (env, kistr) in - (env, concat " :: " kistr) + (env, concat "::" kistr) with (env, kistr) in match getTypeStringCode indent env t.ty with (env, tystr) in (env, join ["all ", idstr, kistr, ". ", tystr]) @@ -1250,7 +1258,7 @@ lang MExprPrettyPrint = UnknownTypePrettyPrint + BoolTypePrettyPrint + IntTypePrettyPrint + FloatTypePrettyPrint + CharTypePrettyPrint + FunTypePrettyPrint + SeqTypePrettyPrint + RecordTypePrettyPrint + VariantTypePrettyPrint + - ConTypePrettyPrint + VarTypePrettyPrint + + ConTypePrettyPrint + DataTypePrettyPrint + VarTypePrettyPrint + AppTypePrettyPrint + TensorTypePrettyPrint + AllTypePrettyPrint + AliasTypePrettyPrint diff --git a/stdlib/mexpr/symbolize.mc b/stdlib/mexpr/symbolize.mc index 872c67edd..0c9b9108a 100644 --- a/stdlib/mexpr/symbolize.mc +++ b/stdlib/mexpr/symbolize.mc @@ -304,7 +304,20 @@ lang ConTypeSym = Sym + ConTypeAst allowFree = env.allowFree} env.tyConEnv t.ident in - TyCon {t with ident = ident} + TyCon {t with ident = ident, data = symbolizeType env t.data} +end + +lang DataTypeSym = Sym + DataTypeAst + sem symbolizeType env = + | TyData t -> + let cons = + setOfSeq nameCmp + (map (getSymbol {kind = "constructor", + info = [t.info], + allowFree = env.allowFree} + env.conEnv) + (setToSeq t.cons)) + in TyData {t with cons = cons} end lang VarTypeSym = Sym + VarTypeAst + UnknownTypeAst @@ -411,7 +424,7 @@ lang MExprSym = MatchSym + -- Non-default implementations (Types) - VariantTypeSym + ConTypeSym + VarTypeSym + AllTypeSym + + VariantTypeSym + ConTypeSym + DataTypeSym + VarTypeSym + AllTypeSym + -- Non-default implementations (Patterns) NamedPatSym + SeqEdgePatSym + DataPatSym + NotPatSym diff --git a/stdlib/mexpr/type-check.mc b/stdlib/mexpr/type-check.mc index d8cf26470..35fb97315 100644 --- a/stdlib/mexpr/type-check.mc +++ b/stdlib/mexpr/type-check.mc @@ -31,9 +31,11 @@ include "mexpr/value.mc" type TCEnv = { varEnv: Map Name Type, - conEnv: Map Name Type, + conEnv: Map Name (Level, Type), tyVarEnv: Map Name Level, tyConEnv: Map Name (Level, [Name], Type), + typeDeps : Map Name (Set Name), -- The set of type names recursively occuring in a type + conDeps : Map Name (Set Name), -- The set of constructors in scope for a type currentLvl: Level, disableRecordPolymorphism: Bool } @@ -43,10 +45,11 @@ let _tcEnvEmpty = { conEnv = mapEmpty nameCmp, tyVarEnv = mapEmpty nameCmp, tyConEnv = - mapFromSeq nameCmp - (map (lam t: (String, [String]). + mapFromSeq nameCmp + (map (lam t. (nameNoSym t.0, (0, map nameSym t.1, tyvariant_ []))) builtinTypes), - + typeDeps = mapEmpty nameCmp, + conDeps = mapEmpty nameCmp, currentLvl = 0, disableRecordPolymorphism = true } @@ -54,9 +57,6 @@ let _tcEnvEmpty = { let _insertVar = lam name. lam ty. lam env : TCEnv. {env with varEnv = mapInsert name ty env.varEnv} -let _insertCon = lam name. lam ty. lam env : TCEnv. - {env with conEnv = mapInsert name ty env.conEnv} - ---------------------- -- TYPE UNIFICATION -- ---------------------- @@ -130,7 +130,7 @@ lang TCUnify = Unify + AliasTypeAst + PrettyPrint + Cmp + MetaVarTypeCmp let msg = join [ "* Expected an expression of type: ", expected, "\n", - "* Found an expression of type: ", + "* Found an expression of type: ", found, "\n", aliases, "* When type checking the expression\n" @@ -155,7 +155,39 @@ lang VarTypeTCUnify = TCUnify + VarTypeAst else () end -lang MetaVarTypeTCUnify = TCUnify + MetaVarTypeUnify + UnifyRecords + RecordTypeAst +lang DataTypeTCUnify = TCUnify + DataTypeAst + KindAst + sem unifyCheckData + : Map Name (Level, Type) + -> Map Name (Level, [Name], Type) + -> [Info] + -> MetaVarRec + -> Map Name (Set Name) + -> () + sem unifyCheckData conEnv tyConEnv info tv = + | data -> + let mkMsg = lam sort. lam n. join [ + "* Encountered a ", sort, " escaping its scope: ", + nameGetStr n, "\n", + "* When type checking the expression\n" + ] in + iter + (lam tks. + if optionMapOr true (lam r. lti tv.level r.0) (mapLookup tks.0 tyConEnv) then + errorSingle info (mkMsg "type constructor" tks.0) + else + iter (lam k. + if optionMapOr true (lam r. lti tv.level r.0) (mapLookup k conEnv) then + errorSingle info (mkMsg "constructor" k) + else ()) + (setToSeq tks.1)) + (mapBindings data) + + sem unifyCheckBase env info boundVars tv = + | TyData t -> + unifyCheckData env.conEnv env.tyConEnv info tv (computeData t) +end + +lang MetaVarTypeTCUnify = TCUnify + MetaVarTypeUnify + UnifyRecords + RecordTypeAst + DataTypeTCUnify sem addKinds : Unifier () -> UnifyEnv -> (Kind, Kind) -> Kind sem addKinds u env = | (Record r1, Record r2) -> @@ -164,6 +196,7 @@ lang MetaVarTypeTCUnify = TCUnify + MetaVarTypeUnify + UnifyRecords + RecordType | (Data r1, Data r2) -> Data {r1 with types = mapUnionWith setUnion r1.types r2.types} | (Mono _ | Poly _, k & !(Mono _ | Poly _)) -> k + | (!(Mono _ | Poly _) & k, Mono _ | Poly _) -> k | (Poly _, k & (Poly _ | Mono _)) -> k | (Mono _, Poly _ | Mono _) -> Mono () | (k1, k2) -> u.err (Kinds (k1, k2)); error "impossible" @@ -183,9 +216,17 @@ lang MetaVarTypeTCUnify = TCUnify + MetaVarTypeUnify + UnifyRecords + RecordType | (TyMetaVar t1 & ty1, !TyMetaVar _ & ty2) -> match deref t1.contents with Unbound tv in unifyCheck tcenv info tv ty2; - (match (tv.kind, ty2) with (Record r1, TyRecord r2) then - unifyRecordsSubset u env r1.fields r2.fields - else match tv.kind with Record _ then u.err (Types (ty1, ty2)) else ()); + (switch (tv.kind, ty2) + case (Record r1, TyRecord r2) then unifyRecordsSubset u env r1.fields r2.fields + case (Data r1, TyData r2) then + let data = computeData r2 in + if mapAllWithKey (lam t. lam ks1. + optionMapOr false (setSubset ks1) (mapLookup t data)) r1.types + then () + else u.err (Types (ty1, ty2)) + case (Record _ | Data _, _) then u.err (Types (ty1, ty2)) + case _ then () + end); modref t1.contents (Link env.wrappedRhs) sem unifyCheckBase env info boundVars tv = @@ -202,8 +243,11 @@ lang MetaVarTypeTCUnify = TCUnify + MetaVarTypeUnify + UnifyRecords + RecordType let kind = match (tv.kind, r.kind) with (Mono _, Poly _) then Mono () else - sfold_Kind_Type - (lam. lam ty. unifyCheckType env info boundVars tv ty) () r.kind; + (match r.kind with Data d then + unifyCheckData env.conEnv env.tyConEnv info tv d.types + else + sfold_Kind_Type + (lam. lam ty. unifyCheckType env info boundVars tv ty) () r.kind); r.kind in let updated = Unbound {r with level = mini r.level tv.level, @@ -229,15 +273,15 @@ end lang ConTypeTCUnify = TCUnify + ConTypeAst sem unifyCheckBase env info boundVars tv = | TyCon t -> - match optionMap (lam r. lti tv.level r.0) (mapLookup t.ident env.tyConEnv) with - !Some false then + if optionMapOr true (lam r. lti tv.level r.0) (mapLookup t.ident env.tyConEnv) then let msg = join [ "* Encountered a type constructor escaping its scope: ", nameGetStr t.ident, "\n", "* When type checking the expression\n" ] in errorSingle info msg - else () + else + unifyCheckType env info boundVars tv t.data end ------------------------------------ @@ -349,18 +393,17 @@ end -- NOTE(aathn, 2023-05-10): In the future, this should be replaced -- with something which also performs a proper kind check. lang ResolveType = ConTypeAst + AppTypeAst + AliasTypeAst + VariantTypeAst + - UnknownTypeAst + VarTypeSubstitute + AppTypeGetArgs - sem resolveType : Info -> Map Name (Level, [Name], Type) -> Type -> Type - sem resolveType info tycons = + UnknownTypeAst + DataTypeAst + VarTypeSubstitute + AppTypeGetArgs + sem resolveType : Info -> TCEnv -> Type -> Type + sem resolveType info env = | (TyCon _ | TyApp _) & ty -> let mkAppTy = foldl (lam ty1. lam ty2. TyApp {info = mergeInfo (infoTy ty1) (infoTy ty2), lhs = ty1, rhs = ty2}) in match getTypeArgs ty with (constr, args) in - let args = map (resolveType info tycons) args in + let args = map (resolveType info env) args in match constr with (TyCon t) & conTy then - match mapLookup t.ident tycons with Some (_, params, def) then - let appTy = mkAppTy conTy args in + match mapLookup t.ident env.tyConEnv with Some (_, params, def) then match def with !TyVariant _ then -- It's an alias match (length params, length args) with (paramLen, argLen) in if eqi paramLen argLen then @@ -368,31 +411,40 @@ lang ResolveType = ConTypeAst + AppTypeAst + AliasTypeAst + VariantTypeAst + (mapEmpty nameCmp) params args in -- We assume def has already been resolved before being put into tycons - TyAlias {display = appTy, content = substituteVars subst def} + TyAlias {display = mkAppTy conTy args, content = substituteVars subst def} else errorSingle [infoTy ty] (join [ - "* Encountered a misformed type alias.\n", + "* Encountered a misformed type constructor or alias.\n", "* Type ", nameGetStr t.ident, " is declared to have ", int2string paramLen, " parameters.\n", "* Found ", int2string argLen, " arguments.\n", "* When checking the annotation" ]) else - appTy + match t.data with TyData d then + let tys = mapLookupOrElse (lam. setEmpty nameCmp) t.ident env.typeDeps in + let universe = + mapMapWithKey (lam s. lam. + match mapLookup s env.conDeps with Some cons in + cons) tys + in + mkAppTy (TyCon {t with data = TyData {d with universe = universe}}) args + else + mkAppTy conTy args else errorSingle [t.info] (join [ "* Encountered an unknown type constructor: ", nameGetStr t.ident, "\n", "* When checking the annotation" ]) else - mkAppTy (resolveType info tycons constr) args + mkAppTy (resolveType info env constr) args -- If we encounter a TyAlias, it means that the type was already processed by -- a previous call to typeCheck. | TyAlias t -> TyAlias t | ty -> - smap_Type_Type (resolveType info tycons) ty + smap_Type_Type (resolveType info env) ty end lang SubstituteUnknown = UnknownTypeAst + KindAst + AliasTypeAst @@ -493,7 +545,7 @@ end lang LamTypeCheck = TypeCheck + LamAst + ResolveType + SubstituteUnknown sem typeCheckExpr env = | TmLam t -> - let tyAnnot = resolveType t.info env.tyConEnv t.tyAnnot in + let tyAnnot = resolveType t.info env t.tyAnnot in let tyParam = substituteUnknown (Mono ()) env.currentLvl t.info tyAnnot in let body = typeCheckExpr (_insertVar t.ident tyParam env) t.body in let tyLam = ityarrow_ t.info tyParam (tyTm body) in @@ -518,7 +570,7 @@ lang LetTypeCheck = sem typeCheckExpr env = | TmLet t -> let newLvl = addi 1 env.currentLvl in - let tyAnnot = resolveType t.info env.tyConEnv t.tyAnnot in + let tyAnnot = resolveType t.info env t.tyAnnot in let tyBody = substituteUnknown (Poly ()) newLvl t.info tyAnnot in match if isValue (GVal ()) t.body then @@ -562,7 +614,7 @@ lang RecLetsTypeCheck = TypeCheck + RecLetsAst + LetTypeCheck + MetaVarDisableGe let newLvl = addi 1 env.currentLvl in -- First: Generate a new environment containing the recursive bindings let recLetEnvIteratee = lam acc. lam b: RecLetBinding. - let tyAnnot = resolveType t.info env.tyConEnv b.tyAnnot in + let tyAnnot = resolveType t.info env b.tyAnnot in let tyBody = substituteUnknown (Poly ()) newLvl t.info tyAnnot in let vars = if isValue (GVal ()) b.body then (stripTyAll tyBody).0 else [] in let newEnv = _insertVar b.ident tyBody acc.0 in @@ -630,12 +682,12 @@ lang MatchTypeCheck = TypeCheck + PatTypeCheck + MatchAst , pat = pat} end -lang ConstTypeCheck = TypeCheck + MExprConstType +lang ConstTypeCheck = TypeCheck + MExprConstType + SubstituteUnknown sem typeCheckExpr env = | TmConst t -> recursive let f = lam ty. smap_Type_Type f (tyWithInfo t.info ty) in - let ty = inst t.info env.currentLvl (f (tyConst t.val)) in - TmConst {t with ty = ty} + let ty = substituteUnknown (Poly ()) env.currentLvl t.info (f (tyConst t.val)) in + TmConst {t with ty = inst t.info env.currentLvl ty} end lang SeqTypeCheck = TypeCheck + SeqAst @@ -658,28 +710,32 @@ lang RecordTypeCheck = TypeCheck + RecordAst + RecordTypeAst let rec = typeCheckExpr env t.rec in let value = typeCheckExpr env t.value in let fields = mapInsert t.key (tyTm value) (mapEmpty cmpSID) in - unify env [infoTm rec] (newrowvar fields env.currentLvl (infoTm rec)) (tyTm rec); + unify env [infoTm rec] (newrecvar fields env.currentLvl (infoTm rec)) (tyTm rec); TmRecordUpdate {t with rec = rec, value = value, ty = tyTm rec} end lang TypeTypeCheck = TypeCheck + TypeAst + VariantTypeAst + ResolveType sem typeCheckExpr env = | TmType t -> - let tyIdent = resolveType t.info env.tyConEnv t.tyIdent in + let tyIdent = resolveType t.info env t.tyIdent in -- NOTE(aathn, 2023-05-08): Aliases are treated as the underlying -- type and do not need to be scope checked. let newLvl = match tyIdent with !TyVariant _ then addi 1 env.currentLvl else 0 in let newTyConEnv = mapInsert t.ident (newLvl, t.params, tyIdent) env.tyConEnv in + let newTypeDeps = mapInsert t.ident (setOfSeq nameCmp [t.ident]) env.typeDeps in + let newConDeps = mapInsert t.ident (setEmpty nameCmp) env.conDeps in let inexpr = typeCheckExpr {env with currentLvl = addi 1 env.currentLvl, - tyConEnv = newTyConEnv} t.inexpr in + tyConEnv = newTyConEnv, + typeDeps = newTypeDeps, + conDeps = newConDeps} t.inexpr in unify env [t.info, infoTm inexpr] (newpolyvar env.currentLvl t.info) (tyTm inexpr); TmType {t with tyIdent = tyIdent, inexpr = inexpr, ty = tyTm inexpr} end lang DataTypeCheck = TypeCheck + DataAst + FunTypeAst + ResolveType - sem _makeConstructorType : Info -> Name -> Type -> Type + sem _makeConstructorType : Info -> Name -> Type -> (Name, Set Name, Type) sem _makeConstructorType info ident = | ty -> let msg = lam. join [ @@ -688,38 +744,54 @@ lang DataTypeCheck = TypeCheck + DataAst + FunTypeAst + ResolveType "* right-hand side should refer to a constructor type.\n", "* When type checking the expression\n" ] in - recursive let substituteData = lam v. lam x. - switch x - case TyCon (t & {data = TyUnknown _}) then - TyCon { t with data = v } - case TyAlias t then - TyAlias { t with content = substituteData v t.content } - case _ then - smap_Type_Type (substituteData v) x - end - in match inspectType ty with TyArrow {to = to & (TyCon _ | TyApp _)} then - match getTypeArgs to with (TyCon t, _) then + match getTypeArgs to with (TyCon target, _) then + recursive let substituteData = lam v. lam acc. lam x. + switch x + case TyCon (t & {data = TyUnknown _}) then + (if nameEq t.ident target.ident then acc else setInsert t.ident acc, + TyCon { t with data = v }) + case TyAlias t then + match substituteData v acc t.content with (acc, content) in + (acc, TyAlias { t with content = content }) + case _ then + smapAccumL_Type_Type (substituteData v) acc x + end + in + let x = nameSym "x" in + match substituteData (TyVar {info = info, ident = x}) (setEmpty nameCmp) ty + with (tydeps, newTy) in let data = Data { - types = mapFromSeq nameCmp [ (t.ident, setOfSeq nameCmp [ ident ]) ] + types = mapFromSeq nameCmp [ (target.ident, setOfSeq nameCmp [ ident ]) ] } in - let x = nameSym "x" in - TyAll { info = info - , ident = x - , kind = data - , ty = substituteData (TyVar {info = info, ident = x}) ty } + (target.ident, + tydeps, + TyAll { info = info + , ident = x + , kind = data + , ty = newTy }) else errorSingle [info] (msg ()) else errorSingle [info] (msg ()) sem typeCheckExpr env = | TmConDef t -> - let tyIdent = resolveType t.info env.tyConEnv t.tyIdent in - let tyIdent = _makeConstructorType t.info t.ident tyIdent in - let inexpr = typeCheckExpr (_insertCon t.ident tyIdent env) t.inexpr in + let tyIdent = resolveType t.info env t.tyIdent in + match _makeConstructorType t.info t.ident tyIdent with (target, tydeps, tyIdent) in + let newLvl = addi 1 env.currentLvl in + let inexpr = + typeCheckExpr + {env with currentLvl = newLvl, + conEnv = mapInsert t.ident (newLvl, tyIdent) env.conEnv, + typeDeps = mapInsertWith setUnion target tydeps env.typeDeps, + conDeps = mapInsertWith setUnion target + (setOfSeq nameCmp [t.ident]) env.conDeps} + t.inexpr + in + unify env [t.info, infoTm inexpr] (newpolyvar env.currentLvl t.info) (tyTm inexpr); TmConDef {t with tyIdent = tyIdent, inexpr = inexpr, ty = tyTm inexpr} | TmConApp t -> let body = typeCheckExpr env t.body in - match mapLookup t.ident env.conEnv with Some lty then + match mapLookup t.ident env.conEnv with Some (_, lty) then match inst t.info env.currentLvl lty with TyArrow {from = from, to = to} in unify env [infoTm body] from (tyTm body); TmConApp {t with body = body, ty = to} @@ -771,7 +843,7 @@ end lang ExtTypeCheck = TypeCheck + ExtAst + ResolveType sem typeCheckExpr env = | TmExt t -> - let tyIdent = resolveType t.info env.tyConEnv t.tyIdent in + let tyIdent = resolveType t.info env t.tyIdent in let env = {env with varEnv = mapInsert t.ident tyIdent env.varEnv} in let inexpr = typeCheckExpr env t.inexpr in TmExt {t with tyIdent = tyIdent, inexpr = inexpr, ty = tyTm inexpr} @@ -827,14 +899,14 @@ lang RecordPatTypeCheck = PatTypeCheck + RecordPat | PatRecord t -> let typeCheckBinding = lam patEnv. lam. lam pat. typeCheckPat env patEnv pat in match mapMapAccum typeCheckBinding patEnv t.bindings with (patEnv, bindings) in - let ty = newrowvar (mapMap tyPat bindings) env.currentLvl t.info in + let ty = newrecvar (mapMap tyPat bindings) env.currentLvl t.info in (patEnv, PatRecord {t with bindings = bindings, ty = ty}) end lang DataPatTypeCheck = PatTypeCheck + DataPat + FunTypeAst + Generalize sem typeCheckPat env patEnv = | PatCon t -> - match mapLookup t.ident env.conEnv with Some ty then + match mapLookup t.ident env.conEnv with Some (_, ty) then match inst t.info env.currentLvl ty with TyArrow {from = from, to = to} in match typeCheckPat env patEnv t.subpat with (patEnv, subpat) in unify env [infoPat subpat] from (tyPat subpat); @@ -883,6 +955,7 @@ lang MExprTypeCheck = -- Type unification MExprUnify + VarTypeTCUnify + MetaVarTypeTCUnify + AllTypeTCUnify + ConTypeTCUnify + + DataTypeTCUnify + -- Type generalization MetaVarTypeGeneralize + VarTypeGeneralize + AllTypeGeneralize + @@ -1253,7 +1326,7 @@ let tests = [ (mapInsert (stringToSid "y") wb (mapEmpty cmpSID)) in - let r = newrowvar fields 0 (NoInfo ()) in + let r = newrecvar fields 0 (NoInfo ()) in tyarrows_ [r, wa, wb, r], env = []}, diff --git a/stdlib/mexpr/type.mc b/stdlib/mexpr/type.mc index 67d05a2e7..b44b6695e 100644 --- a/stdlib/mexpr/type.mc +++ b/stdlib/mexpr/type.mc @@ -100,7 +100,9 @@ lang MetaVarTypePrettyPrint = IdentifierPrettyPrint + KindPrettyPrint + MetaVarT switch t.kind case Poly () then (env, idstr) case Mono () then (env, concat "_" idstr) - case _ then getKindStringCode indent env t.kind + case _ then + match getKindStringCode indent env t.kind with (env, str) in + (env, join [init str, " ...", [last str]]) end case Link ty then getTypeStringCode indent env ty @@ -134,7 +136,7 @@ let newmonovar = use KindAst in newmetavar (Mono ()) let newpolyvar = use KindAst in newmetavar (Poly ()) -let newrowvar = use KindAst in +let newrecvar = use KindAst in lam fields. newmetavar (Record {fields = fields}) let newvar = newpolyvar diff --git a/stdlib/mexpr/unify.mc b/stdlib/mexpr/unify.mc index bb121c68d..63b6f2b5b 100644 --- a/stdlib/mexpr/unify.mc +++ b/stdlib/mexpr/unify.mc @@ -153,6 +153,14 @@ lang ConTypeUnify = Unify + ConTypeAst u.err (Types (ty1, ty2)) end +lang DataTypeUnify = Unify + DataTypeAst + sem unifyBase u env = + | (TyData t1 & ty1, TyData t2 & ty2) -> + if mapEq setEq (computeData t1) (computeData t2) then u.empty + else + u.err (Types (ty1, ty2)) +end + lang BoolTypeUnify = Unify + BoolTypeAst sem unifyBase u env = | (TyBool _, TyBool _) -> u.empty @@ -245,8 +253,8 @@ end lang MExprUnify = VarTypeUnify + MetaVarTypeUnify + FunTypeUnify + AppTypeUnify + AllTypeUnify + - ConTypeUnify + BoolTypeUnify + IntTypeUnify + FloatTypeUnify + CharTypeUnify + - SeqTypeUnify + TensorTypeUnify + RecordTypeUnify + ConTypeUnify + DataTypeUnify + BoolTypeUnify + IntTypeUnify + FloatTypeUnify + + CharTypeUnify + SeqTypeUnify + TensorTypeUnify + RecordTypeUnify end lang TestLang = UnifyPure + MExprUnify + MExprEq + MetaVarTypeEq end