diff --git a/codon/compiler/compiler.cpp b/codon/compiler/compiler.cpp index 1465b809..ff297810 100644 --- a/codon/compiler/compiler.cpp +++ b/codon/compiler/compiler.cpp @@ -102,8 +102,9 @@ Compiler::parse(bool isCode, const std::string &file, const std::string &code, auto fo = fopen("_dump_typecheck.sexp", "w"); fmt::print(fo, "{}\n", typechecked->toString(0)); for (auto &f : cache->functions) - for (auto &r : f.second.realizations) + for (auto &r : f.second.realizations) { fmt::print(fo, "{}\n", r.second->ast->toString(0)); + } fclose(fo); } diff --git a/codon/parser/ast/stmt.cpp b/codon/parser/ast/stmt.cpp index a538a703..02b86f10 100644 --- a/codon/parser/ast/stmt.cpp +++ b/codon/parser/ast/stmt.cpp @@ -154,12 +154,12 @@ ForStmt::ForStmt(ExprPtr var, ExprPtr iter, StmtPtr suite, StmtPtr elseSuite, ExprPtr decorator, std::vector ompArgs) : Stmt(), var(std::move(var)), iter(std::move(iter)), suite(std::move(suite)), elseSuite(std::move(elseSuite)), decorator(std::move(decorator)), - ompArgs(std::move(ompArgs)), wrapped(false) {} + ompArgs(std::move(ompArgs)), wrapped(false), flat(false) {} ForStmt::ForStmt(const ForStmt &stmt) : Stmt(stmt), var(ast::clone(stmt.var)), iter(ast::clone(stmt.iter)), suite(ast::clone(stmt.suite)), elseSuite(ast::clone(stmt.elseSuite)), decorator(ast::clone(stmt.decorator)), ompArgs(ast::clone_nop(stmt.ompArgs)), - wrapped(stmt.wrapped) {} + wrapped(stmt.wrapped), flat(stmt.flat) {} std::string ForStmt::toString(int indent) const { std::string pad = indent > 0 ? ("\n" + std::string(indent + INDENT_SIZE, ' ')) : " "; std::string attr; diff --git a/codon/parser/ast/stmt.h b/codon/parser/ast/stmt.h index 293ed5fa..45596ce7 100644 --- a/codon/parser/ast/stmt.h +++ b/codon/parser/ast/stmt.h @@ -265,6 +265,8 @@ struct ForStmt : public Stmt { /// Indicates if iter was wrapped with __iter__() call. bool wrapped; + /// True if there are no break/continue within the loop + bool flat; ForStmt(ExprPtr var, ExprPtr iter, StmtPtr suite, StmtPtr elseSuite = nullptr, ExprPtr decorator = nullptr, std::vector ompArgs = {}); diff --git a/codon/parser/visitors/simplify/ctx.h b/codon/parser/visitors/simplify/ctx.h index 31c11bf7..b0f14c9e 100644 --- a/codon/parser/visitors/simplify/ctx.h +++ b/codon/parser/visitors/simplify/ctx.h @@ -129,6 +129,9 @@ struct SimplifyContext : public Context { /// List of variables "seen" before their assignment within a loop. /// Used to dominate variables that are updated within a loop. std::unordered_set seenVars; + /// False if a loop has continue/break statement. Used for flattening static + /// loops. + bool flat = true; }; std::vector loops; diff --git a/codon/parser/visitors/simplify/loops.cpp b/codon/parser/visitors/simplify/loops.cpp index 0fdb71c0..672c00b4 100644 --- a/codon/parser/visitors/simplify/loops.cpp +++ b/codon/parser/visitors/simplify/loops.cpp @@ -18,6 +18,7 @@ namespace codon::ast { void SimplifyVisitor::visit(ContinueStmt *stmt) { if (!ctx->getBase()->getLoop()) E(Error::EXPECTED_LOOP, stmt, "continue"); + ctx->getBase()->getLoop()->flat = false; } /// Ensure that `break` is in a loop. @@ -28,6 +29,7 @@ void SimplifyVisitor::visit(ContinueStmt *stmt) { void SimplifyVisitor::visit(BreakStmt *stmt) { if (!ctx->getBase()->getLoop()) E(Error::EXPECTED_LOOP, stmt, "break"); + ctx->getBase()->getLoop()->flat = false; if (!ctx->getBase()->getLoop()->breakVar.empty()) { resultStmt = N( transform(N(N(ctx->getBase()->getLoop()->breakVar), @@ -116,6 +118,8 @@ void SimplifyVisitor::visit(ForStmt *stmt) { stmt->suite = transform(N(stmts)); } + if (ctx->getBase()->getLoop()->flat) + stmt->flat = true; // Complete while-else clause if (stmt->elseSuite && stmt->elseSuite->firstInBlock()) { resultStmt = N(assign, N(*stmt), diff --git a/codon/parser/visitors/typecheck/access.cpp b/codon/parser/visitors/typecheck/access.cpp index 3411fc44..174cf294 100644 --- a/codon/parser/visitors/typecheck/access.cpp +++ b/codon/parser/visitors/typecheck/access.cpp @@ -326,11 +326,15 @@ ExprPtr TypecheckVisitor::getClassMember(DotExpr *expr, // Case: transform `union.m` to `__internal__.get_union_method(union, "m", ...)` if (typ->getUnion()) { + if (!typ->canRealize()) + return nullptr; // delay! + // bool isMember = false; + // for (auto &t: typ->getUnion()->getRealizationTypes()) + // if (ctx->findMethod(t.get(), expr->member).empty()) return transform(N( - N("__internal__.get_union_method:0"), + N("__internal__.union_member:0"), std::vector{{"union", expr->expr}, - {"method", N(expr->member)}, - {"", N(EllipsisExpr::PARTIAL)}})); + {"member", N(expr->member)}})); } // For debugging purposes: diff --git a/codon/parser/visitors/typecheck/call.cpp b/codon/parser/visitors/typecheck/call.cpp index 0da5158e..2db58451 100644 --- a/codon/parser/visitors/typecheck/call.cpp +++ b/codon/parser/visitors/typecheck/call.cpp @@ -738,6 +738,8 @@ ExprPtr TypecheckVisitor::transformIsInstance(CallExpr *expr) { return transform(N(typ->getRecord() != nullptr)); } else if (typExpr->isId("ByRef")) { return transform(N(typ->getRecord() == nullptr)); + } else if (typExpr->isId("Union")) { + return transform(N(typ->getUnion() != nullptr)); } else if (!typExpr->type->getUnion() && typ->getUnion()) { auto unionTypes = typ->getUnion()->getRealizationTypes(); int tag = -1; @@ -997,10 +999,6 @@ std::pair TypecheckVisitor::transformInternalStaticFn(CallExpr *e if (!typ) return {true, nullptr}; - auto fn = expr->args[0].value->type->getFunc(); - if (!fn) - error("expected a function, got '{}'", expr->args[0].value->type->prettyString()); - auto inargs = unpackTupleTypes(expr->args[1].value); auto kwargs = unpackTupleTypes(expr->args[2].value); seqassert(inargs && kwargs, "bad call to fn_can_call"); @@ -1014,6 +1012,25 @@ std::pair TypecheckVisitor::transformInternalStaticFn(CallExpr *e callArgs.push_back({a.first, std::make_shared()}); // dummy expression callArgs.back().value->setType(a.second); } + + auto fn = expr->args[0].value->type->getFunc(); + if (!fn) { + bool canCompile = true; + // Special case: not a function, just try compiling it! + auto ocache = *(ctx->cache); + auto octx = *ctx; + try { + transform(N(clone(expr->args[0].value), + N(clone(expr->args[1].value)), + N(clone(expr->args[2].value)))); + } catch (const exc::ParserException &e) { + // LOG("{}", e.what()); + canCompile = false; + *ctx = octx; + *(ctx->cache) = ocache; + } + return {true, transform(N(canCompile))}; + } return {true, transform(N(canCall(fn, callArgs) >= 0))}; } else if (expr->expr->isId("std.internal.static.fn_arg_has_type")) { expr->staticValue.type = StaticValue::INT; diff --git a/codon/parser/visitors/typecheck/class.cpp b/codon/parser/visitors/typecheck/class.cpp index d745c52d..d97bff83 100644 --- a/codon/parser/visitors/typecheck/class.cpp +++ b/codon/parser/visitors/typecheck/class.cpp @@ -178,14 +178,20 @@ std::string TypecheckVisitor::generateTuple(size_t len, const std::string &name, if (startswith(typeName, TYPE_KWTUPLE)) stmt->getClass()->suite = N(getItem, contains, getDef); - // Add repr for KwArgs: + // Add repr and call for partials: // `def __repr__(self): return __magic__.repr_partial(self)` auto repr = N( "__repr__", nullptr, std::vector{Param{"self"}}, N(N(N( N(N("__magic__"), "repr_partial"), N("self"))))); + auto pcall = N( + "__call__", nullptr, + std::vector{Param{"self"}, Param{"*args"}, Param{"**kwargs"}}, + N( + N(N(N("self"), N(N("args")), + N(N("kwargs")))))); if (startswith(typeName, TYPE_PARTIAL)) - stmt->getClass()->suite = repr; + stmt->getClass()->suite = N(repr, pcall); // Simplify in the standard library context and type check stmt = SimplifyVisitor::apply(ctx->cache->imports[STDLIB_IMPORT].ctx, stmt, diff --git a/codon/parser/visitors/typecheck/error.cpp b/codon/parser/visitors/typecheck/error.cpp index 8385ecbe..4fa74012 100644 --- a/codon/parser/visitors/typecheck/error.cpp +++ b/codon/parser/visitors/typecheck/error.cpp @@ -27,6 +27,22 @@ using namespace types; /// f = exc; ...; break # PyExc /// raise``` void TypecheckVisitor::visit(TryStmt *stmt) { + // TODO: static can-compile check + // if (stmt->catches.size() == 1 && stmt->catches[0].var.empty() && + // stmt->catches[0].exc->isId("std.internal.types.error.StaticCompileError")) { + // /// TODO: this is right now _very_ dangerous; inferred types here will remain! + // bool compiled = true; + // try { + // auto nctx = std::make_shared(*ctx); + // TypecheckVisitor(nctx).transform(clone(stmt->suite)); + // } catch (const exc::ParserException &exc) { + // compiled = false; + // } + // resultStmt = compiled ? transform(stmt->suite) : + // transform(stmt->catches[0].suite); LOG("testing!! {} {}", getSrcInfo(), + // compiled); return; + // } + ctx->blockLevel++; transform(stmt->suite); ctx->blockLevel--; diff --git a/codon/parser/visitors/typecheck/function.cpp b/codon/parser/visitors/typecheck/function.cpp index b687d57c..9139a73f 100644 --- a/codon/parser/visitors/typecheck/function.cpp +++ b/codon/parser/visitors/typecheck/function.cpp @@ -160,6 +160,10 @@ types::FuncTypePtr TypecheckVisitor::makeFunctionType(FunctionStmt *stmt) { ctx->typecheckLevel++; if (stmt->ret) { unify(baseType->generics[1].type, transformType(stmt->ret)->getType()); + if (stmt->ret->isId("Union")) { + baseType->generics[1].type->getUnion()->generics[0].type->getUnbound()->kind = + LinkType::Generic; + } } else { generics.push_back(unify(baseType->generics[1].type, ctx->getUnbound())); } diff --git a/codon/parser/visitors/typecheck/infer.cpp b/codon/parser/visitors/typecheck/infer.cpp index 2e91ad9c..2cddf3bc 100644 --- a/codon/parser/visitors/typecheck/infer.cpp +++ b/codon/parser/visitors/typecheck/infer.cpp @@ -89,6 +89,18 @@ StmtPtr TypecheckVisitor::inferTypes(StmtPtr result, bool isToplevel) { } if (result->isDone()) { + // Special union case: if union cannot be inferred return type is Union[NoneType] + if (auto tr = ctx->getRealizationBase()->returnType) { + if (auto tu = tr->getUnion()) { + if (!tu->isSealed()) { + if (tu->pendingTypes[0]->getLink() && + tu->pendingTypes[0]->getLink()->kind == LinkType::Unbound) { + tu->addType(ctx->forceFind("NoneType")->type); + tu->seal(); + } + } + } + } break; } else if (changedNodes) { continue; @@ -353,6 +365,10 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type, bool force) if (!ret) { realizations.erase(key); + ctx->realizationBases.pop_back(); + ctx->popBlock(); + ctx->typecheckLevel--; + getLogger().level--; if (!startswith(ast->name, "._lambda")) { // Lambda typecheck failures are "ignored" as they are treated as statements, // not functions. @@ -360,10 +376,6 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type, bool force) // LOG("{}", ast->suite->toString(2)); error("cannot typecheck the program"); } - ctx->realizationBases.pop_back(); - ctx->popBlock(); - ctx->typecheckLevel--; - getLogger().level--; return nullptr; // inference must be delayed } @@ -836,127 +848,19 @@ TypecheckVisitor::generateSpecialAst(types::FuncType *type) { N("__internal__.new_union:0"), N(type->ast->args[0].name), N(unionType->realizedTypeName()))); ast->suite = suite; - } else if (startswith(ast->name, "__internal__.new_union:0")) { - // Special case: __internal__.new_union - // def __internal__.new_union(value, U[T0, ..., TN]): - // if isinstance(value, T0): - // return __internal__.union_make(0, value, U[T0, ..., TN]) - // if isinstance(value, Union[T0]): - // return __internal__.union_make( - // 0, __internal__.get_union(value, T0), U[T0, ..., TN]) - // ... ... - // compile_error("invalid union constructor") - auto unionType = type->funcGenerics[0].type->getUnion(); - auto unionTypes = unionType->getRealizationTypes(); - - auto objVar = ast->args[0].name; - auto suite = N(); - int tag = 0; - for (auto &t : unionTypes) { - suite->stmts.push_back(N( - N(N("isinstance"), N(objVar), - NT(t->realizedName())), - N(N(N("__internal__.union_make:0"), - N(tag), N(objVar), - N(unionType->realizedTypeName()))))); - // Check for Union[T] - suite->stmts.push_back(N( - N( - N("isinstance"), N(objVar), - NT(NT("Union"), - std::vector{NT(t->realizedName())})), - N( - N(N("__internal__.union_make:0"), N(tag), - N(N("__internal__.get_union:0"), - N(objVar), NT(t->realizedName())), - N(unionType->realizedTypeName()))))); - tag++; - } - suite->stmts.push_back(N(N( - N("compile_error"), N("invalid union constructor")))); - ast->suite = suite; - } else if (startswith(ast->name, "__internal__.get_union:0")) { - // Special case: __internal__.get_union - // def __internal__.new_union(union: Union[T0,...,TN], T): - // if __internal__.union_get_tag(union) == 0: - // return __internal__.union_get_data(union, T0) - // ... - // raise TypeError("getter") - auto unionType = type->getArgTypes()[0]->getUnion(); - auto unionTypes = unionType->getRealizationTypes(); - - auto targetType = type->funcGenerics[0].type; - auto selfVar = ast->args[0].name; - auto suite = N(); - int tag = 0; - for (auto t : unionTypes) { - if (t->realizedName() == targetType->realizedName()) { - suite->stmts.push_back(N( - N(N(N("__internal__.union_get_tag:0"), - N(selfVar)), - "==", N(tag)), - N(N(N("__internal__.union_get_data:0"), - N(selfVar), - NT(t->realizedName()))))); - } - tag++; - } - suite->stmts.push_back( - N(N(N("std.internal.types.error.TypeError"), - N("invalid union getter")))); - ast->suite = suite; - } else if (startswith(ast->name, "__internal__._get_union_method:0")) { - // def __internal__._get_union_method(union: Union[T0,...,TN], method, *args, **kw): - // if __internal__.union_get_tag(union) == 0: - // return __internal__.union_get_data(union, T0).method(*args, **kw) - // ... - // raise TypeError("call") - auto szt = type->funcGenerics[0].type->getStatic(); - auto fnName = szt->evaluate().getString(); - auto unionType = type->getArgTypes()[0]->getUnion(); - auto unionTypes = unionType->getRealizationTypes(); - - auto selfVar = ast->args[0].name; - auto suite = N(); - int tag = 0; - for (auto &t : unionTypes) { - auto callee = - N(N(N("__internal__.union_get_data:0"), - N(selfVar), NT(t->realizedName())), - fnName); - auto args = N(N(ast->args[2].name.substr(1))); - auto kwargs = N(N(ast->args[3].name.substr(2))); - std::vector callArgs; - ExprPtr check = - N(N("hasattr"), NT(t->realizedName()), - N(fnName), args->clone(), kwargs->clone()); - suite->stmts.push_back(N( - N( - check, "&&", - N(N(N("__internal__.union_get_tag:0"), - N(selfVar)), - "==", N(tag))), - N(N(N(callee, args, kwargs))))); - tag++; - } - suite->stmts.push_back( - N(N(N("std.internal.types.error.TypeError"), - N("invalid union call")))); - // suite->stmts.push_back(N(N())); - - auto ret = ctx->instantiate(ctx->getType("Union")); - unify(type->getRetType(), ret); - ast->suite = suite; - } else if (startswith(ast->name, "__internal__.get_union_first:0")) { - // def __internal__.get_union_first(union: Union[T0]): + } else if (startswith(ast->name, "__internal__.get_union_tag:0")) { + // def __internal__.get_union_tag(union: Union, tag: Static[int]): // return __internal__.union_get_data(union, T0) + auto szt = type->funcGenerics[0].type->getStatic(); + auto tag = szt->evaluate().getInt(); auto unionType = type->getArgTypes()[0]->getUnion(); auto unionTypes = unionType->getRealizationTypes(); - + if (tag < 0 || tag >= unionTypes.size()) + E(Error::CUSTOM, getSrcInfo(), "bad union tag"); auto selfVar = ast->args[0].name; auto suite = N(N( N(N("__internal__.union_get_data:0"), N(selfVar), - NT(unionTypes[0]->realizedName())))); + NT(unionTypes[tag]->realizedName())))); ast->suite = suite; } return ast; diff --git a/codon/parser/visitors/typecheck/loops.cpp b/codon/parser/visitors/typecheck/loops.cpp index 0bb4b206..7c37e400 100644 --- a/codon/parser/visitors/typecheck/loops.cpp +++ b/codon/parser/visitors/typecheck/loops.cpp @@ -164,6 +164,7 @@ StmtPtr TypecheckVisitor::transformHeterogenousTupleFor(ForStmt *stmt) { /// while loop: /// i = x; ; break /// loop = False # also set to False on break +/// If a loop is flat, while wrappers are removed. /// A separate suite is generated for each static iteration. StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) { auto var = stmt->var->getId()->value; @@ -188,13 +189,19 @@ StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) { if (vars.size() > 1) vars.erase(vars.begin()); auto [ok, items] = transformStaticLoopCall(vars, stmt->iter, [&](StmtPtr assigns) { - auto brk = N(); - brk->setDone(); // Avoid transforming this one to continue - // var [: Static] := expr; suite... - auto loop = N(N(loopVar), - N(assigns, clone(stmt->suite), brk)); - loop->gotoVar = loopVar; - return loop; + StmtPtr ret = nullptr; + if (!stmt->flat) { + auto brk = N(); + brk->setDone(); // Avoid transforming this one to continue + // var [: Static] := expr; suite... + auto loop = N(N(loopVar), + N(assigns, clone(stmt->suite), brk)); + loop->gotoVar = loopVar; + ret = loop; + } else { + ret = N(assigns, clone(stmt->suite)); + } + return ret; }); if (!ok) { if (oldSuite) @@ -203,17 +210,21 @@ StmtPtr TypecheckVisitor::transformStaticForLoop(ForStmt *stmt) { } // Close the loop - ctx->blockLevel++; - auto a = N(N(loopVar), N(false)); - a->setUpdate(); auto block = N(); for (auto &i : items) block->stmts.push_back(std::dynamic_pointer_cast(i)); - block->stmts.push_back(a); - auto loop = - transform(N(N(N(loopVar), N(true)), - N(N(loopVar), block))); - ctx->blockLevel--; + StmtPtr loop = nullptr; + if (!stmt->flat) { + ctx->blockLevel++; + auto a = N(N(loopVar), N(false)); + a->setUpdate(); + block->stmts.push_back(a); + loop = transform(N(N(N(loopVar), N(true)), + N(N(loopVar), block))); + ctx->blockLevel--; + } else { + loop = transform(block); + } return loop; } @@ -310,17 +321,18 @@ TypecheckVisitor::transformStaticLoopCall( error("expected two items"); if (auto fna = ctx->getFunctionArgs(fn->type)) { auto [generics, args] = *fna; - auto typ = args[0]->getRecord(); - if (!typ) + if (auto typ = args[0]->getRecord()) { + for (size_t i = 0; i < typ->args.size(); i++) { + auto b = N( + {N(N(vars[0]), N(i), + NT(NT("Static"), NT("int"))), + N(N(vars[1]), + N(iter->getCall()->args[0].value->clone(), + N(i)))}); + block.push_back(wrap(b)); + } + } else { error("staticenumerate needs a tuple"); - for (size_t i = 0; i < typ->args.size(); i++) { - auto b = N( - {N(N(vars[0]), N(i), - NT(NT("Static"), NT("int"))), - N(N(vars[1]), - N(iter->getCall()->args[0].value->clone(), - N(i)))}); - block.push_back(wrap(b)); } } else { error("bad call to staticenumerate"); @@ -369,21 +381,38 @@ TypecheckVisitor::transformStaticLoopCall( seqassert(typ, "vars_types expects a realizable type, got '{}' instead", generics[0]); - size_t idx = 0; - for (auto &f : getClassFields(typ->getClass().get())) { - auto ta = realize(ctx->instantiate(f.type, typ->getClass())); - seqassert(ta, "cannot realize '{}'", f.type->debugString(1)); - std::vector stmts; - if (withIdx) { + + if (auto utyp = typ->getUnion()) { + for (size_t i = 0; i < utyp->getRealizationTypes().size(); i++) { + std::vector stmts; + if (withIdx) { + stmts.push_back( + N(N(vars[0]), N(i), + NT(NT("Static"), NT("int")))); + } stmts.push_back( - N(N(vars[0]), N(idx), - NT(NT("Static"), NT("int")))); + N(N(vars[1]), + N(utyp->getRealizationTypes()[i]->realizedName()))); + auto b = N(stmts); + block.push_back(wrap(b)); + } + } else { + size_t idx = 0; + for (auto &f : getClassFields(typ->getClass().get())) { + auto ta = realize(ctx->instantiate(f.type, typ->getClass())); + seqassert(ta, "cannot realize '{}'", f.type->debugString(1)); + std::vector stmts; + if (withIdx) { + stmts.push_back( + N(N(vars[0]), N(idx), + NT(NT("Static"), NT("int")))); + } + stmts.push_back( + N(N(vars[withIdx]), NT(ta->realizedName()))); + auto b = N(stmts); + block.push_back(wrap(b)); + idx++; } - stmts.push_back( - N(N(vars[withIdx]), NT(ta->realizedName()))); - auto b = N(stmts); - block.push_back(wrap(b)); - idx++; } } else { error("bad call to vars"); diff --git a/codon/parser/visitors/typecheck/typecheck.cpp b/codon/parser/visitors/typecheck/typecheck.cpp index 93030123..c7eb9566 100644 --- a/codon/parser/visitors/typecheck/typecheck.cpp +++ b/codon/parser/visitors/typecheck/typecheck.cpp @@ -423,8 +423,9 @@ bool TypecheckVisitor::wrapExpr(ExprPtr &expr, const TypePtr &expectedType, } } else if (exprClass && expectedClass && expectedClass->getUnion()) { // Make union types via __internal__.new_union - if (!expectedClass->getUnion()->isSealed()) + if (!expectedClass->getUnion()->isSealed()) { expectedClass->getUnion()->addType(exprClass); + } if (auto t = realize(expectedClass)) { if (expectedClass->unify(exprClass.get(), nullptr) == -1) expr = transform(N(N("__internal__.new_union:0"), expr, diff --git a/stdlib/internal/core.codon b/stdlib/internal/core.codon index b2febe89..d761f1c7 100644 --- a/stdlib/internal/core.codon +++ b/stdlib/internal/core.codon @@ -200,6 +200,8 @@ class Union[TU]: # compiler-generated def __new__(val): TU + def __call__(self, *args, **kwargs): + return __internal__.union_call(self, args, kwargs) # dummy @__internal__ diff --git a/stdlib/internal/internal.codon b/stdlib/internal/internal.codon index 87a5dfc4..a40cf100 100644 --- a/stdlib/internal/internal.codon +++ b/stdlib/internal/internal.codon @@ -4,7 +4,7 @@ from internal.gc import ( alloc, alloc_atomic, alloc_atomic_uncollectable, free, sizeof, register_finalizer ) -from internal.static import vars_types, tuple_type, vars as _vars +from internal.static import vars_types, tuple_type, vars as _vars, fn_overloads, fn_can_call def vars(obj, with_index: Static[int] = 0): return _vars(obj, with_index) @@ -148,6 +148,9 @@ class __internal__: # Unions + def get_union_tag(u, tag: Static[int]): # compiler-generated + pass + @llvm def union_set_tag(tag: byte, U: type) -> U: %0 = insertvalue {=U} undef, i8 %tag, 0 @@ -172,22 +175,59 @@ class __internal__: return u def new_union(value, U: type) -> U: - pass + for tag, T in vars_types(U, with_index=1): + if isinstance(value, T): + return __internal__.union_make(tag, value, U) + if isinstance(value, Union[T]): + return __internal__.union_make(tag, __internal__.get_union(value, T), U) + # TODO: make this static! + raise TypeError("invalid union constructor") def get_union(union, T: type) -> T: - pass - - def get_union_first(union): - pass - - def _get_union_method(union, method: Static[str], *args, **kwargs): - pass - - def get_union_method(union, method: Static[str], *args, **kwargs): - t = __internal__._get_union_method(union, method, *args, **kwargs) + for tag, TU in vars_types(union, with_index=1): + if isinstance(TU, T): + if __internal__.union_get_tag(union) == tag: + return __internal__.union_get_data(union, TU) + raise TypeError(f"invalid union getter for type '{T.__class__.__name__}'") + + def _union_member_helper(union, member: Static[str]) -> Union: + for tag, T in vars_types(union, with_index=1): + if hasattr(T, member): + if __internal__.union_get_tag(union) == tag: + return getattr(__internal__.union_get_data(union, T), member) + raise TypeError(f"invalid union call '{member}'") + + def union_member(union, member: Static[str]): + t = __internal__._union_member_helper(union, member) if staticlen(t) == 1: - return __internal__.get_union_first(t) - return t + return __internal__.get_union_tag(t, 0) + else: + return t + + def _union_call_helper(union, args, kwargs) -> Union: + for tag, T in vars_types(union, with_index=1): + if hasattr(T, '__call__'): + if fn_can_call(__internal__.union_get_data(union, T), *args, **kwargs): + if __internal__.union_get_tag(union) == tag: + return __internal__.union_get_data(union, T).__call__(*args, **kwargs) + raise TypeError("cannot call union") + + def union_call(union, args, kwargs): + t = __internal__._union_call_helper(union, args, kwargs) + if staticlen(t) == 1: + return __internal__.get_union_tag(t, 0) + else: + return t + + def union_str(union): + for tag, T in vars_types(union, with_index=1): + if hasattr(T, '__str__'): + if __internal__.union_get_tag(union) == tag: + return __internal__.union_get_data(union, T).__str__() + elif hasattr(T, '__repr__'): + if __internal__.union_get_tag(union) == tag: + return __internal__.union_get_data(union, T).__repr__() + return '' # Tuples diff --git a/stdlib/internal/static.codon b/stdlib/internal/static.codon index 2d974d9c..565a5746 100644 --- a/stdlib/internal/static.codon +++ b/stdlib/internal/static.codon @@ -39,4 +39,3 @@ def vars_types(T: type, with_index: Static[int] = 0): def tuple_type(T: type, N: Static[int]): pass - diff --git a/stdlib/internal/types/error.codon b/stdlib/internal/types/error.codon index a0f8ea90..f0bb722a 100644 --- a/stdlib/internal/types/error.codon +++ b/stdlib/internal/types/error.codon @@ -167,3 +167,7 @@ class SystemExit(Static[BaseException]): @property def status(self): return self._status + +class StaticCompileError(Static[Exception]): + def __init__(self, message: str = ""): + super().__init__("StaticCompileError", message) diff --git a/stdlib/internal/types/str.codon b/stdlib/internal/types/str.codon index aea45b06..1d5ec054 100644 --- a/stdlib/internal/types/str.codon +++ b/stdlib/internal/types/str.codon @@ -18,7 +18,9 @@ class str: return str(Ptr[byte](), 0) def __new__(what) -> str: - if hasattr(what, "__str__"): + if isinstance(what, Union): + return __internal__.union_str(what) + elif hasattr(what, "__str__"): return what.__str__() else: return what.__repr__() diff --git a/test/parser/types.codon b/test/parser/types.codon index 1efb0393..1cb612bd 100644 --- a/test/parser/types.codon +++ b/test/parser/types.codon @@ -1443,10 +1443,10 @@ print(a, b, a.__class__.__name__, b.__class__.__name__) #: 11 70 Union[int,str] if True: a = 'hello' foo_str(a) #: hello str - foo(a) #: 'hello' 1 Union[int,str] - # b = a[1:3] - # print(b) -print(a) #: 'hello' + foo(a) #: hello 1 Union[int,str] + b = a[1:3] + print(b) #: el +print(a) #: hello a: Union[Union[Union[str], int], Union[int, int, str]] = 9 foo(a) #: 9 0 Union[int,str] @@ -1460,7 +1460,7 @@ def ret(x): r = ret(2) print(r, r.__class__.__name__) #: False Union[bool,int,str] r = ret(33.3) -print(r, r.__class__.__name__) #: 'oops' Union[bool,float,int,str] +print(r, r.__class__.__name__) #: oops Union[bool,float,int,str] def ret2(x) -> Union: if x < 1: return 1 @@ -1477,11 +1477,17 @@ class B: y: str def foo(self): return f"B: {self.y}" -x : Union[A,B] = A(5) # TODO: just Union does not work :/ +x : Union[A,B] = A(5) # TODO: just Union does not work in test mode :/ print(x.foo()) #: A: 5 +print(x.x) #: 5 if True: x = B("bee") print(x.foo()) #: B: bee +print(x.y) #: bee +try: + print(x.x) +except TypeError as e: + print(e.message) #: invalid union call 'x' def do(x: A): print('do', x.x) @@ -1527,7 +1533,7 @@ try: a = "foo" print(a == 123) except TypeError: - print("oops", a) #: oops 'foo' + print("oops", a) #: oops foo #%% generator_capture_nonglobal,barebones