Skip to content

Commit

Permalink
fix: trait impl bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Dec 26, 2024
1 parent 9045fa2 commit 017b13f
Show file tree
Hide file tree
Showing 5 changed files with 262 additions and 73 deletions.
91 changes: 88 additions & 3 deletions crates/erg_compiler/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ use crate::hir::DefaultParamSignature;
use crate::hir::GlobSignature;
use crate::hir::ListWithLength;
use crate::hir::{
Accessor, Args, BinOp, Block, Call, ClassDef, Def, DefBody, Expr, GuardClause, Identifier,
Lambda, List, Literal, NonDefaultParamSignature, Params, PatchDef, PosArg, ReDef, Record,
Signature, SubrSignature, Tuple, UnaryOp, VarSignature, HIR,
Accessor, Args, BinOp, Block, Call, ClassDef, Def, DefBody, Dict, Expr, GuardClause,
Identifier, Lambda, List, Literal, NonDefaultParamSignature, Params, PatchDef, PosArg, ReDef,
Record, Set, Signature, SubrSignature, Tuple, UnaryOp, VarSignature, HIR,
};
use crate::ty::codeobj::{CodeObj, CodeObjFlags, MakeFunctionFlags};
use crate::ty::value::{GenTypeObj, ValueObj};
Expand Down Expand Up @@ -864,6 +864,51 @@ impl PyCodeGenerator {
self.emit_args_311(args, AccessKind::Name);
return;
}
"list_iterator" => {
let list = Expr::Literal(Literal::new(ValueObj::List(vec![].into()), Token::DUMMY));
let iter = Identifier::static_public("iter");
let iter_call = iter.call(Args::single(PosArg::new(list)));
let typ = Identifier::static_public("type");
let typ_call = typ.call(Args::single(PosArg::new(iter_call.into())));
self.emit_call(typ_call);
return;
}
"set_iterator" => {
let set = Expr::Set(Set::empty());
let iter = Identifier::static_public("iter");
let iter_call = iter.call(Args::single(PosArg::new(set)));
let typ = Identifier::static_public("type");
let typ_call = typ.call(Args::single(PosArg::new(iter_call.into())));
self.emit_call(typ_call);
return;
}
"dict_items" => {
let dict = Expr::Dict(Dict::empty());
let items = Identifier::static_public("iter");
let items_call = items.call(Args::single(PosArg::new(dict)));
let typ = Identifier::static_public("type");
let typ_call = typ.call(Args::single(PosArg::new(items_call.into())));
self.emit_call(typ_call);
return;
}
"dict_keys" => {
let dict = Expr::Dict(Dict::empty());
let keys = Identifier::static_public("keys");
let keys_call = dict.method_call(keys, Args::empty());
let typ = Identifier::static_public("type");
let typ_call = typ.call(Args::single(PosArg::new(keys_call.into())));
self.emit_call(typ_call);
return;
}
"dict_values" => {
let dict = Expr::Dict(Dict::empty());
let values = Identifier::static_public("values");
let values_call = dict.method_call(values, Args::empty());
let typ = Identifier::static_public("type");
let typ_call = typ.call(Args::single(PosArg::new(values_call.into())));
self.emit_call(typ_call);
return;
}
_ => {}
}
let name = self
Expand Down Expand Up @@ -2754,6 +2799,46 @@ impl PyCodeGenerator {
self.emit_load_name_instr(Identifier::private("#sum"));
self.emit_args_311(args, Name);
}
"ListIterator" => {
let list = Expr::Literal(Literal::new(ValueObj::List(vec![].into()), Token::DUMMY));
let iter = Identifier::static_public("iter");
let iter_call = iter.call(Args::single(PosArg::new(list)));
let typ = Identifier::static_public("type");
let typ_call = typ.call(Args::single(PosArg::new(iter_call.into())));
self.emit_call(typ_call);
}
"SetIterator" => {
let set = Expr::Set(Set::empty());
let iter = Identifier::static_public("iter");
let iter_call = iter.call(Args::single(PosArg::new(set)));
let typ = Identifier::static_public("type");
let typ_call = typ.call(Args::single(PosArg::new(iter_call.into())));
self.emit_call(typ_call);
}
"DictItems" => {
let dict = Expr::Dict(Dict::empty());
let iter = Identifier::static_public("iter");
let items_call = iter.call(Args::single(PosArg::new(dict)));
let typ = Identifier::static_public("type");
let typ_call = typ.call(Args::single(PosArg::new(items_call.into())));
self.emit_call(typ_call);
}
"DictKeys" => {
let dict = Expr::Dict(Dict::empty());
let keys = Identifier::static_public("keys");
let keys_call = dict.method_call(keys, Args::empty());
let typ = Identifier::static_public("type");
let typ_call = typ.call(Args::single(PosArg::new(keys_call.into())));
self.emit_call(typ_call);
}
"DictValues" => {
let dict = Expr::Dict(Dict::empty());
let values = Identifier::static_public("values");
let values_call = dict.method_call(values, Args::empty());
let typ = Identifier::static_public("type");
let typ_call = typ.call(Args::single(PosArg::new(values_call.into())));
self.emit_call(typ_call);
}
other if local.ref_t().is_poly_meta_type() && other != "classof" => {
if self.py_version.minor <= Some(9) {
self.load_fake_generic();
Expand Down
74 changes: 49 additions & 25 deletions crates/erg_compiler/context/initialize/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ impl Context {
let mut named = Self::builtin_mono_trait(NAMED, 2);
named.register_builtin_erg_decl(FUNC_NAME, Str, Visibility::BUILTIN_PUBLIC);
let mut sized = Self::builtin_mono_trait(SIZED, 2);
let t = fn0_met(mono(SIZED), Nat).quantify();
let ret_t = if PYTHON_MODE { Int } else { Nat };
let t = fn0_met(mono(SIZED), ret_t).quantify();
sized.register_builtin_erg_decl(FUNDAMENTAL_LEN, t, Visibility::BUILTIN_PUBLIC);
let mut copy = Self::builtin_mono_trait(COPY, 2);
let Slf = mono_q(SELF, subtypeof(mono(COPY)));
Expand Down Expand Up @@ -227,15 +228,24 @@ impl Context {
/* Iterable */
let mut iterable = Self::builtin_poly_trait(ITERABLE, vec![PS::t_nd(TY_T)], 2);
iterable.register_superclass(poly(OUTPUT, vec![ty_tp(T.clone())]), &output);
let Slf = mono_q(SELF, subtypeof(poly(ITERABLE, vec![ty_tp(T.clone())])));
let t = fn0_met(Slf.clone(), proj(Slf, ITER)).quantify();
iterable.register_builtin_decl(
FUNC_ITER,
t,
Visibility::BUILTIN_PUBLIC,
Some(FUNDAMENTAL_ITER),
);
iterable.register_builtin_erg_decl(ITER, Type, Visibility::BUILTIN_PUBLIC);
if PYTHON_MODE {
let t = fn0_met(
poly(ITERABLE, vec![ty_tp(T.clone())]),
poly(ITERATOR, vec![ty_tp(T.clone())]),
)
.quantify();
iterable.register_builtin_erg_decl(FUNDAMENTAL_ITER, t, Visibility::BUILTIN_PUBLIC);
} else {
let Slf = mono_q(SELF, subtypeof(poly(ITERABLE, vec![ty_tp(T.clone())])));
let t = fn0_met(Slf.clone(), proj(Slf, ITER)).quantify();
iterable.register_builtin_decl(
FUNC_ITER,
t,
Visibility::BUILTIN_PUBLIC,
Some(FUNDAMENTAL_ITER),
);
iterable.register_builtin_erg_decl(ITER, Type, Visibility::BUILTIN_PUBLIC);
}
let Slf = poly(ITERABLE, vec![ty_tp(T.clone())]);
let U = type_q(TY_U);
let t_map = fn1_met(
Expand All @@ -244,9 +254,10 @@ impl Context {
poly(MAP, vec![ty_tp(U.clone())]),
)
.quantify();
iterable.register_builtin_decl(
iterable.register_builtin_py_impl(
FUNC_MAP,
t_map,
Mutability::Immutable,
Visibility::BUILTIN_PUBLIC,
Some("Function::iterable_map"),
);
Expand All @@ -268,9 +279,10 @@ impl Context {
)
.quantify();
let t_filter = t_filter.with_default_intersec_index(1);
iterable.register_builtin_decl(
iterable.register_builtin_py_impl(
FUNC_FILTER,
t_filter,
Mutability::Immutable,
Visibility::BUILTIN_PUBLIC,
Some("Function::iterable_filter"),
);
Expand All @@ -279,9 +291,10 @@ impl Context {
vec![TyParam::List(vec![ty_tp(Nat), ty_tp(T.clone())])],
);
let t_enumerate = fn0_met(Slf.clone(), poly(ITERATOR, vec![ty_tp(ret_t)])).quantify();
iterable.register_builtin_decl(
iterable.register_builtin_py_impl(
FUNC_ENUMERATE,
t_enumerate,
Mutability::Immutable,
Visibility::BUILTIN_PUBLIC,
Some("Function::enumerate"),
);
Expand All @@ -291,9 +304,10 @@ impl Context {
poly(ZIP, vec![ty_tp(T.clone()), ty_tp(U.clone())]),
)
.quantify();
iterable.register_builtin_decl(
iterable.register_builtin_py_impl(
FUNC_ZIP,
t_zip,
Mutability::Immutable,
Visibility::BUILTIN_PUBLIC,
Some("Function::zip"),
);
Expand All @@ -304,59 +318,67 @@ impl Context {
T.clone(),
)
.quantify();
iterable.register_builtin_decl(
iterable.register_builtin_py_impl(
FUNC_REDUCE,
t_reduce,
Mutability::Immutable,
Visibility::BUILTIN_PUBLIC,
Some("Function::iterable_reduce"),
);
let t_nth = fn1_met(Slf.clone(), Nat, T.clone()).quantify();
iterable.register_builtin_decl(
iterable.register_builtin_py_impl(
FUNC_NTH,
t_nth,
Mutability::Immutable,
Visibility::BUILTIN_PUBLIC,
Some("Function::iterable_nth"),
);
let t_skip = fn1_met(Slf.clone(), Nat, poly(ITERATOR, vec![ty_tp(T.clone())])).quantify();
iterable.register_builtin_decl(
iterable.register_builtin_py_impl(
FUNC_SKIP,
t_skip,
Mutability::Immutable,
Visibility::BUILTIN_PUBLIC,
Some("Function::iterable_skip"),
);
let t_all = fn1_met(Slf.clone(), func1(T.clone(), Bool), Bool).quantify();
iterable.register_builtin_decl(
iterable.register_builtin_py_impl(
FUNC_ALL,
t_all,
Mutability::Immutable,
Visibility::BUILTIN_PUBLIC,
Some("Function::iterable_all"),
);
let t_any = fn1_met(Slf.clone(), func1(T.clone(), Bool), Bool).quantify();
iterable.register_builtin_decl(
iterable.register_builtin_py_impl(
FUNC_ANY,
t_any,
Mutability::Immutable,
Visibility::BUILTIN_PUBLIC,
Some("Function::iterable_any"),
);
let t_reversed = fn0_met(Slf.clone(), poly(ITERATOR, vec![ty_tp(T.clone())])).quantify();
iterable.register_builtin_decl(
iterable.register_builtin_py_impl(
FUNC_REVERSED,
t_reversed,
Mutability::Immutable,
Visibility::BUILTIN_PUBLIC,
Some("Function::reversed"),
);
let t_position = fn1_met(Slf.clone(), func1(T.clone(), Bool), or(Nat, NoneType)).quantify();
iterable.register_builtin_decl(
iterable.register_builtin_py_impl(
FUNC_POSITION,
t_position,
Mutability::Immutable,
Visibility::BUILTIN_PUBLIC,
Some("Function::iterable_position"),
);
let t_find =
fn1_met(Slf.clone(), func1(T.clone(), Bool), or(T.clone(), NoneType)).quantify();
iterable.register_builtin_decl(
iterable.register_builtin_py_impl(
FUNC_FIND,
t_find,
Mutability::Immutable,
Visibility::BUILTIN_PUBLIC,
Some("Function::iterable_find"),
);
Expand All @@ -369,16 +391,18 @@ impl Context {
poly(ITERATOR, vec![ty_tp(T.clone())]),
)
.quantify();
iterable.register_builtin_decl(
iterable.register_builtin_py_impl(
FUNC_CHAIN,
t_chain,
Mutability::Immutable,
Visibility::BUILTIN_PUBLIC,
Some("Function::iterable_chain"),
);
let t_to_list = fn0_met(Slf.clone(), unknown_len_list_t(T.clone())).quantify();
iterable.register_builtin_decl(
iterable.register_builtin_py_impl(
FUNC_TO_LIST,
t_to_list,
Mutability::Immutable,
Visibility::BUILTIN_PUBLIC,
Some("Function::list"),
);
Expand All @@ -396,7 +420,7 @@ impl Context {
);
/* Container */
let mut container = Self::builtin_poly_trait(CONTAINER, vec![PS::t_nd(TY_T)], 2);
let op_t = fn1_met(mono(CONTAINER), T.clone(), Bool).quantify();
let op_t = fn1_met(poly(CONTAINER, vec![ty_tp(T.clone())]), T.clone(), Bool).quantify();
container.register_superclass(poly(OUTPUT, vec![ty_tp(T.clone())]), &output);
container.register_builtin_erg_decl(FUNDAMENTAL_CONTAINS, op_t, Visibility::BUILTIN_PUBLIC);
/* Collection */
Expand Down
33 changes: 33 additions & 0 deletions crates/erg_compiler/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,10 @@ impl Identifier {
Call::new(Expr::Accessor(Accessor::Ident(self)), None, args)
}

pub fn method_call(self, attr_name: Identifier, args: Args) -> Call {
Call::new(Expr::Accessor(Accessor::Ident(self)), Some(attr_name), args)
}

pub fn is_py_api(&self) -> bool {
self.vi.py_name.is_some()
}
Expand Down Expand Up @@ -1095,6 +1099,14 @@ impl_display_for_enum!(Dict; Normal, Comprehension);
impl_locational_for_enum!(Dict; Normal, Comprehension);
impl_t_for_enum!(Dict; Normal, Comprehension);

impl Dict {
pub fn empty() -> Self {
let l_brace = Token::from_str(TokenKind::LBrace, "{");
let r_brace = Token::from_str(TokenKind::RBrace, "}");
Self::Normal(NormalDict::new(l_brace, r_brace, HashMap::new(), vec![]))
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NormalSet {
pub l_brace: Token,
Expand Down Expand Up @@ -1194,6 +1206,19 @@ impl_display_for_enum!(Set; Normal, WithLength);
impl_locational_for_enum!(Set; Normal, WithLength);
impl_t_for_enum!(Set; Normal, WithLength);

impl Set {
pub fn empty() -> Self {
let l_brace = Token::from_str(TokenKind::LBrace, "{");
let r_brace = Token::from_str(TokenKind::RBrace, "}");
Self::Normal(NormalSet::new(
l_brace,
r_brace,
Type::Uninited,
Args::empty(),
))
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RecordAttrs(Vec<Def>);

Expand Down Expand Up @@ -3123,6 +3148,14 @@ impl Expr {
))
}

pub fn method_call(self, attr_name: Identifier, args: Args) -> Call {
Call::new(self, Some(attr_name), args)
}

pub fn method_call_expr(self, attr_name: Identifier, args: Args) -> Self {
Self::Call(self.method_call(attr_name, args))
}

pub fn attr(self, ident: Identifier) -> Accessor {
Accessor::attr(self, ident)
}
Expand Down
Loading

0 comments on commit 017b13f

Please sign in to comment.