Skip to content

Commit

Permalink
Add start on constructor type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
aathn committed Nov 9, 2023
1 parent 3bab896 commit 91294a2
Show file tree
Hide file tree
Showing 15 changed files with 150 additions and 100 deletions.
6 changes: 3 additions & 3 deletions stdlib/cuda/wrapper.mc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ lang CudaCWrapperBase = PMExprCWrapper + CudaAst + MExprAst + CudaCompile
let fields = mapFromSeq cmpSID fieldsSeq in
let ty = TyRecord {t with fields = fields} in
optionMap
(lam id. TyCon {ident = id, info = t.info})
(lam id. nitycon_ id t.info)
(mapLookup ty cenv.revTypeEnv)
else None ()
| ty & (TySeq {ty = elemTy} | TyTensor {ty = elemTy}) ->
Expand All @@ -104,12 +104,12 @@ lang CudaCWrapperBase = PMExprCWrapper + CudaAst + MExprAst + CudaCompile
else never in
match env with CudaTargetEnv cenv in
match mapLookup ty cenv.revTypeEnv with Some id then
Some (TyCon {ident = id, info = infoTy ty})
Some (nitycon_ id (infoTy ty))
else None ()
| ty & (TyVariant _) ->
match env with CudaTargetEnv cenv in
match mapLookup ty cenv.revTypeEnv with Some id then
Some (TyCon {ident = id, info = infoTy ty})
Some (nitycon_ id (infoTy ty))
else None ()
| TyAlias t -> lookupTypeIdent env t.content
| ty -> Some ty
Expand Down
12 changes: 9 additions & 3 deletions stdlib/mexpr/ast-builder.mc
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,15 @@ let tyalias_ = use AliasTypeAst in
lam display. lam content.
TyAlias {display = display, content = content}

let ntycon_ = use ConTypeAst in
lam n.
TyCon {ident = n, info = NoInfo ()}
let nsitycon_ = use ConTypeAst in
lam n. lam d. lam i.
TyCon {ident = n, data = d, info = i}

let nitycon_ = lam n. lam i.
nsitycon_ n tyunknown_ i

let ntycon_ = lam n.
nitycon_ n (NoInfo ())

let tycon_ = lam s.
ntycon_ (nameNoSym s)
Expand Down
16 changes: 12 additions & 4 deletions stdlib/mexpr/ast.mc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ include "name.mc"
include "string.mc"
include "stringid.mc"
include "map.mc"
include "set.mc"

-----------
-- TERMS --
Expand Down Expand Up @@ -1398,11 +1399,17 @@ end
lang ConTypeAst = Ast
syn Type =
| TyCon {info : Info,
ident : Name}
ident : Name,
data : Type}

sem tyWithInfo (info : Info) =
| TyCon t -> TyCon {t with info = info}

sem smapAccumL_Type_Type (f : acc -> Type -> (acc, Type)) (acc : acc) =
| TyCon t ->
match f acc t.data with (acc, data) in
(acc, TyCon {t with data = data})

sem infoTy =
| TyCon r -> r.info
end
Expand All @@ -1424,13 +1431,14 @@ lang KindAst
syn Kind =
| Poly ()
| Mono ()
| Row {fields : Map SID Type}
| Record {fields : Map SID Type}
| Data {types : Map Name (Set Name)}

sem smapAccumL_Kind_Type : all acc. (acc -> Type -> (acc, Type)) -> acc -> Kind -> (acc, Kind)
sem smapAccumL_Kind_Type (f : acc -> Type -> (acc, Type)) (acc : acc) =
| Row r ->
| Record r ->
match mapMapAccum (lam acc. lam. lam e. f acc e) acc r.fields with (acc, flds) in
(acc, Row {r with fields = flds})
(acc, Record {r with fields = flds})
| s ->
(acc, s)

Expand Down
3 changes: 2 additions & 1 deletion stdlib/mexpr/boot-parser.mc
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ lang BootParser = MExprAst + ConstTransformer
else error "Parsing of non-empty variant types not yet supported"
| 209 /-TyCon-/ ->
TyCon {info = ginfo t 0,
ident = gname t 0}
ident = gname t 0,
data = TyUnknown { info = ginfo t 0 }}
| 210 /-TyVar-/ ->
TyVar {info = ginfo t 0,
ident = gname t 0}
Expand Down
9 changes: 7 additions & 2 deletions stdlib/mexpr/cmp.mc
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,10 @@ end

lang ConTypeCmp = Cmp + ConTypeAst
sem cmpTypeH =
| (TyCon t1, TyCon t2) -> nameCmp t1.ident t2.ident
| (TyCon t1, TyCon t2) ->
let nameDiff = nameCmp t1.ident t2.ident in
if eqi nameDiff 0 then cmpType t1.data t2.data
else nameDiff
end

lang VarTypeCmp = Cmp + VarTypeAst
Expand All @@ -415,8 +418,10 @@ end

lang KindCmp = Cmp + KindAst
sem cmpKind =
| (Row l, Row r) ->
| (Record l, Record r) ->
mapCmp cmpType l.fields r.fields
| (Data l, Data r) ->
mapCmp setCmp l.types r.types
| (lhs, rhs) ->
subi (constructorTag lhs) (constructorTag rhs)
end
Expand Down
8 changes: 6 additions & 2 deletions stdlib/mexpr/eq.mc
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,8 @@ lang ConTypeEq = Eq + ConTypeAst
sem eqTypeH (typeEnv : EqTypeEnv) (free : EqTypeFreeEnv) (lhs : Type) =
| rhs & TyCon r ->
match unwrapType lhs with TyCon l then
if nameEq l.ident r.ident then Some free else None ()
if nameEq l.ident r.ident then eqTypeH typeEnv free l.data r.data
else None ()
else None ()
end

Expand All @@ -648,7 +649,7 @@ end

lang KindEq = Eq + KindAst
sem eqKind (typeEnv : EqTypeEnv) (free : EqTypeFreeEnv) =
| (Row l, Row r) ->
| (Record l, Record r) ->
if eqi (mapSize l.fields) (mapSize r.fields) then
mapFoldlOption
(lam free. lam k1. lam v1.
Expand All @@ -657,6 +658,9 @@ lang KindEq = Eq + KindAst
else None ())
free l.fields
else None ()
| (Data l, Data r) ->
if mapEq setEq l.types r.types then Some free
else None ()
| (lhs, rhs) ->
if eqi (constructorTag lhs) (constructorTag rhs) then Some free
else None ()
Expand Down
2 changes: 1 addition & 1 deletion stdlib/mexpr/monomorphize.mc
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ lang MonomorphizeApply = MonomorphizeInstantiate + MonomorphizeResymbolize
match mapLookup t.ident env.typeEnv with Some instEntry then
let typeInst = findTypeInstantiation instEntry.polyType ty in
match mapLookup typeInst instEntry.map with Some newId then
TyCon {ident = newId, info = infoTy ty}
TyCon {t with ident = newId, info = infoTy ty}
else
monoError [t.info] "Invalid type constructor instantiation"
else
Expand Down
30 changes: 23 additions & 7 deletions stdlib/mexpr/pprint.mc
Original file line number Diff line number Diff line change
Expand Up @@ -1144,10 +1144,13 @@ lang VariantTypePrettyPrint = VariantTypeAst
-- still use TyVariant in the AST and might get compilation errors for it.
end

lang ConTypePrettyPrint = IdentifierPrettyPrint + ConTypeAst
lang ConTypePrettyPrint = IdentifierPrettyPrint + ConTypeAst + UnknownTypeAst
sem getTypeStringCode (indent : Int) (env: PprintEnv) =
| TyCon t ->
pprintTypeName env t.ident
match pprintTypeName env t.ident with (env, idstr) in
match t.data with TyUnknown {} then (env, idstr) else
match getTypeStringCode indent env t.data with (env, str) in
(env, join [str, ".", idstr])
end

lang VarTypePrettyPrint = IdentifierPrettyPrint + VarTypeAst
Expand All @@ -1157,12 +1160,21 @@ lang VarTypePrettyPrint = IdentifierPrettyPrint + VarTypeAst
end

lang KindPrettyPrint = PrettyPrint + RecordTypeAst + KindAst
sem getKindStringCode (indent : Int) (env : PprintEnv) (idstr : String) =
| Row r ->
sem getKindStringCode (indent : Int) (env : PprintEnv) =
| Record r ->
let recty = TyRecord {info = NoInfo (), fields = r.fields} in
match getTypeStringCode indent env recty with (env, recstr) in
(env, join [init recstr, " ... ", [last recstr]])
| _ -> (env, idstr)
| Data r ->
let tstr =
mapFoldWithKey (lam strs. lam t. lam ks.
snoc strs (join [ nameGetStr t, "{"
, strJoin ", " (map nameGetStr (setToSeq ks))
, "}" ])
) "" r.types in
(env, join ["<", strJoin ", " tstr, ">"])
| Poly () -> (env, "*")
| Mono () -> (env, "o")
end

lang AllTypePrettyPrint = IdentifierPrettyPrint + AllTypeAst + KindPrettyPrint
Expand All @@ -1172,9 +1184,13 @@ lang AllTypePrettyPrint = IdentifierPrettyPrint + AllTypeAst + KindPrettyPrint
sem getTypeStringCode (indent : Int) (env: PprintEnv) =
| TyAll t ->
match pprintVarName env t.ident with (env, idstr) in
match getKindStringCode indent env idstr t.kind with (env, varstr) in
match
match t.kind with Mono () | Poly () then (env, "") else
match getKindStringCode indent env t.kind with (env, kistr) in
(env, concat " :: " kistr)
with (env, kistr) in
match getTypeStringCode indent env t.ty with (env, tystr) in
(env, join ["all ", varstr, ". ", tystr])
(env, join ["all ", idstr, kistr, ". ", tystr])
end

lang AppTypePrettyPrint = PrettyPrint + AppTypeAst
Expand Down
86 changes: 43 additions & 43 deletions stdlib/mexpr/type-check.mc
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,18 @@ lang VarTypeTCUnify = TCUnify + VarTypeAst
else ()
end

lang MetaVarTypeTCUnify = TCUnify + MetaVarTypeUnify + UnifyRows + RecordTypeAst
lang MetaVarTypeTCUnify = TCUnify + MetaVarTypeUnify + UnifyRecords + RecordTypeAst
sem addKinds : Unifier () -> UnifyEnv -> (Kind, Kind) -> Kind
sem addKinds u env =
| (Row r1, Row r2) ->
match unifyRowsUnion u env r1.fields r2.fields with (_, fields) in
Row {r1 with fields = fields}
| (Row _ & rv, ! Row _ & tv)
| (! Row _ & tv, Row _ & rv) -> rv
| (Poly _, Poly _) -> Poly ()
| (s1, s2) -> Mono ()
| (Record r1, Record r2) ->
match unifyRecordsUnion u env r1.fields r2.fields with (_, fields) in
Record {r1 with fields = fields}
| (Data r1, Data r2) ->
Data {r1 with types = mapUnionWith setUnion r1.types r2.types}
| (Mono _ | Poly _, k & !(Mono _ | Poly _)) -> k
| (Poly _, k & (Poly _ | Mono _)) -> k
| (Mono _, Poly _ | Mono _) -> Mono ()
| (k1, k2) -> u.err (Kinds (k1, k2)); error "impossible"

sem unifyMeta u tcenv info env =
| (TyMetaVar t1 & ty1, TyMetaVar t2 & ty2) ->
Expand All @@ -181,9 +183,9 @@ lang MetaVarTypeTCUnify = TCUnify + MetaVarTypeUnify + UnifyRows + RecordTypeAst
| (TyMetaVar t1 & ty1, !TyMetaVar _ & ty2) ->
match deref t1.contents with Unbound tv in
unifyCheck tcenv info tv ty2;
(match (tv.kind, ty2) with (Row r1, TyRecord r2) then
unifyRowsSubset u env r1.fields r2.fields
else match tv.kind with Row _ then u.err (Types (ty1, ty2)) else ());
(match (tv.kind, ty2) with (Record r1, TyRecord r2) then
unifyRecordsSubset u env r1.fields r2.fields
else match tv.kind with Record _ then u.err (Types (ty1, ty2)) else ());
modref t1.contents (Link env.wrappedRhs)

sem unifyCheckBase env info boundVars tv =
Expand Down Expand Up @@ -333,7 +335,7 @@ lang MetaVarDisableGeneralize = MetaVarTypeAst
sem disableRecordGeneralize (lvl : Level) =
| TyMetaVar t & ty ->
switch deref t.contents
case Unbound {kind = Row _} then
case Unbound {kind = Record _} then
weakenMetaVars lvl ty
case Unbound _ then ()
case Link tyL then
Expand Down Expand Up @@ -385,9 +387,6 @@ lang ResolveType = ConTypeAst + AppTypeAst + AliasTypeAst + VariantTypeAst +
else
mkAppTy (resolveType info tycons constr) args

| TyUnknown _ & ty ->
ty

-- If we encounter a TyAlias, it means that the type was already processed by
-- a previous call to typeCheck.
| TyAlias t -> TyAlias t
Expand All @@ -410,7 +409,7 @@ lang RemoveMetaVar = MetaVarTypeAst + UnknownTypeAst + RecordTypeAst
sem removeMetaVarType =
| TyMetaVar t ->
switch deref t.contents
case Unbound {kind = Row x} then
case Unbound {kind = Record x} then
TyRecord {info = t.info, fields = mapMap removeMetaVarType x.fields}
case Unbound _ then TyUnknown { info = t.info }
case Link ty then removeMetaVarType ty
Expand Down Expand Up @@ -680,41 +679,42 @@ lang TypeTypeCheck = TypeCheck + TypeAst + VariantTypeAst + ResolveType
end

lang DataTypeCheck = TypeCheck + DataAst + FunTypeAst + ResolveType
-- NOTE(larshum, 2023-09-07): Verify that the annotated type of a constructor
-- is of the form we expect, and provide understandable error messages
-- otherwise.
sem _checkConstructorType : Info -> Name -> Type -> ()
sem _checkConstructorType info ident =
sem _makeConstructorType : Info -> Name -> Type -> Type
sem _makeConstructorType info ident =
| ty ->
recursive let isValidConstructorType = lam ty.
switch ty
case TyCon _ then true
case TyApp {lhs = lhs} then isValidConstructorType lhs
case _ then false
let msg = lam. join [
"* Invalid type of constructor: ", nameGetStr ident, "\n",
"* The constructor should have a function type, where the\n",
"* right-hand side should refer to a constructor type.\n",
"* When type checking the expression\n"
] in
recursive let substituteData = lam v. lam x.
switch x
case TyCon (t & {data = TyUnknown _}) then
TyCon { t with data = v }
case TyAlias t then
TyAlias { t with content = substituteData v t.content }
case _ then
smap_Type_Type (substituteData v) x
end
in
match inspectType ty with TyArrow {to = to & (TyCon _ | TyApp _)} then
if isValidConstructorType to then ()
else
let msg = join [
"* Invalid type of constructor: ", nameGetStr ident, "\n",
"* The right-hand side should refer to a constructor type.\n",
"* When type checking the expression\n"
] in
errorSingle [info] msg
else
let msg = join [
"* Invalid type of constructor: ", nameGetStr ident, "\n",
"* The constructor should be given type A -> B, where B\n",
" is a fully applied datatype in scope.\n",
"* When type checking the expression\n"
] in
errorSingle [info] msg
match getTypeArgs to with (TyCon t, _) then
let data = Data {
types = mapFromSeq nameCmp [ (t.ident, setOfSeq nameCmp [ ident ]) ]
} in
let x = nameSym "x" in
TyAll { info = info
, ident = x
, kind = data
, ty = substituteData (TyVar {info = info, ident = x}) ty }
else errorSingle [info] (msg ())
else errorSingle [info] (msg ())

sem typeCheckExpr env =
| TmConDef t ->
let tyIdent = resolveType t.info env.tyConEnv t.tyIdent in
_checkConstructorType t.info t.ident tyIdent;
let tyIdent = _makeConstructorType t.info t.ident tyIdent in
let inexpr = typeCheckExpr (_insertCon t.ident tyIdent env) t.inexpr in
TmConDef {t with tyIdent = tyIdent, inexpr = inexpr, ty = tyTm inexpr}
| TmConApp t ->
Expand Down
12 changes: 6 additions & 6 deletions stdlib/mexpr/type-lift.mc
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ lang TypeLiftAddRecordToEnv = TypeLiftBase + RecordTypeAst
| TyRecord {fields = fields, info = info} & ty ->
switch mapLookup fields env.records
case Some name then
let tycon = TyCon {ident = name, info = info} in
let tycon = nitycon_ name info in
(env, tycon)
case None _ then
let name = nameSym "Rec" in
let tycon = TyCon {ident = name, info = info} in
let tycon = nitycon_ name info in
let env = {{env
with records = mapInsert fields name env.records}
with typeEnv = assocSeqInsert name ty env.typeEnv}
Expand All @@ -106,11 +106,11 @@ lang TypeLiftAddSeqToEnv = TypeLiftBase + SeqTypeAst + ConTypeAst
sem addSeqToEnv (env: TypeLiftEnv) =
| TySeq {info = info, ty = innerTy} & ty ->
match mapLookup innerTy env.seqs with Some name then
let tycon = TyCon {ident = name, info = info} in
let tycon = nitycon_ name info in
(env, tycon)
else
let name = nameSym "Seq" in
let tycon = TyCon {ident = name, info = info} in
let tycon = nitycon_ name info in
let env = {{env with seqs = mapInsert innerTy name env.seqs}
with typeEnv = assocSeqInsert name ty env.typeEnv}
in
Expand All @@ -121,11 +121,11 @@ lang TypeLiftAddTensorToEnv = TypeLiftBase + TensorTypeAst + ConTypeAst
sem addTensorToEnv (env : TypeLiftEnv) =
| TyTensor {info = info, ty = innerTy} & ty ->
match mapLookup innerTy env.tensors with Some name then
let tycon = TyCon {ident = name, info = info} in
let tycon = nitycon_ name info in
(env, tycon)
else
let name = nameSym "Tensor" in
let tycon = TyCon {ident = name, info = info} in
let tycon = nitycon_ name info in
let env = {{env with tensors = mapInsert innerTy name env.tensors}
with typeEnv = assocSeqInsert name ty env.typeEnv}
in
Expand Down
Loading

0 comments on commit 91294a2

Please sign in to comment.