Skip to content

Commit

Permalink
chore: add ast::VarPattern::Phi
Browse files Browse the repository at this point in the history
  • Loading branch information
mtshiba committed Oct 20, 2024
1 parent 9c1e77e commit 01a5938
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 26 deletions.
8 changes: 8 additions & 0 deletions crates/erg_common/triple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ impl<T, E> Triple<T, E> {
}
}

pub fn or_else_triple(self, f: impl FnOnce() -> Triple<T, E>) -> Triple<T, E> {
match self {
Triple::None => f(),
Triple::Ok(ok) => Triple::Ok(ok),
Triple::Err(err) => Triple::Err(err),
}
}

pub fn unwrap_or(self, default: T) -> T {
match self {
Triple::None => default,
Expand Down
17 changes: 17 additions & 0 deletions crates/erg_compiler/context/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,23 @@ impl Context {
&& var_params_judge
&& default_check() // contravariant
}
// {Int} <: Obj -> Int
(Subr(_) | Quantified(_), Refinement(refine))
if rhs.singleton_value().is_some() && self.subtype_of(&refine.t, &ClassType) =>
{
let Ok(typ) = self.convert_tp_into_type(rhs.singleton_value().unwrap().clone())
else {
return false;
};
let Some(ctx) = self.get_nominal_type_ctx(&typ) else {
return false;
};
if let Some((_, __call__)) = ctx.get_class_attr("__call__") {
self.supertype_of(lhs, &__call__.t)
} else {
false
}
}
// ?T(<: Int) :> ?U(:> Nat)
// ?T(<: Int) :> ?U(:> Int)
// ?T(<: Nat) !:> ?U(:> Int) (if the upper bound of LHS is smaller than the lower bound of RHS, LHS cannot not be a supertype)
Expand Down
10 changes: 7 additions & 3 deletions crates/erg_compiler/context/initialize/procs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl Context {
)
.quantify();
let t_proc_ret = if PYTHON_MODE { Obj } else { NoneType };
let t_for = nd_proc(
let t_for = proc(
vec![
kw("iterable", poly("Iterable", vec![ty_tp(T.clone())])),
kw(
Expand All @@ -69,6 +69,8 @@ impl Context {
),
],
None,
vec![kw("else!", nd_proc(vec![], None, t_proc_ret.clone()))],
None,
NoneType,
)
.quantify();
Expand All @@ -90,12 +92,14 @@ impl Context {
// not Bool! type because `cond` may be the result of evaluation of a mutable object's method returns Bool.
nd_proc(vec![], None, Bool)
};
let t_while = nd_proc(
let t_while = proc(
vec![
kw("cond!", t_cond),
kw("proc!", nd_proc(vec![], None, t_proc_ret)),
kw("proc!", nd_proc(vec![], None, t_proc_ret.clone())),
],
None,
vec![kw("else!", nd_proc(vec![], None, t_proc_ret.clone()))],
None,
NoneType,
);
let P = mono_q("P", subtypeof(mono("PathLike")));
Expand Down
21 changes: 21 additions & 0 deletions crates/erg_compiler/context/inquire.rs
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,27 @@ impl Context {
Triple::None
}

pub(crate) fn rec_get_param_or_decl_info(&self, name: &str) -> Option<VarInfo> {
if let Some(vi) = self
.params
.iter()
.find(|(var_name, _)| var_name.as_ref().is_some_and(|n| n.inspect() == name))
.map(|(_, vi)| vi)
.or_else(|| self.decls.get(name))
{
return Some(vi.clone());
}
for method_ctx in self.methods_list.iter() {
if let Some(vi) = method_ctx.rec_get_param_or_decl_info(name) {
return Some(vi);
}
}
if let Some(parent) = self.get_outer_scope().or_else(|| self.get_builtins()) {
return parent.rec_get_param_or_decl_info(name);
}
None
}

pub(crate) fn get_attr_info(
&self,
obj: &hir::Expr,
Expand Down
4 changes: 2 additions & 2 deletions crates/erg_compiler/context/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl Context {
let mut errs = TyCheckErrors::empty();
let muty = Mutability::from(&sig.inspect().unwrap_or(UBAR)[..]);
let ident = match &sig.pat {
ast::VarPattern::Ident(ident) => ident,
ast::VarPattern::Ident(ident) | ast::VarPattern::Phi(ident) => ident,
ast::VarPattern::Discard(_) | ast::VarPattern::Glob(_) => {
return Ok(());
}
Expand Down Expand Up @@ -287,7 +287,7 @@ impl Context {
None
};
let ident = match &sig.pat {
ast::VarPattern::Ident(ident) => ident,
ast::VarPattern::Ident(ident) | ast::VarPattern::Phi(ident) => ident,
ast::VarPattern::Discard(_) => {
return Ok(VarInfo {
t: body_t.clone(),
Expand Down
2 changes: 2 additions & 0 deletions crates/erg_compiler/context/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1874,6 +1874,8 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
}
self.sub_unify_pred(&sub.pred, &supe.pred)?;
}
// {Int} <: Obj -> Int
(Refinement(_), Subr(_) | Quantified(_)) if maybe_sub.singleton_value().is_some() => {}
// {I: Int | I >= 1} <: Nat == {I: Int | I >= 0}
(Refinement(_), sup) => {
let sup = sup.clone().into_refinement();
Expand Down
58 changes: 43 additions & 15 deletions crates/erg_compiler/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2253,6 +2253,11 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
errors.extend(errs);
}
let outer = self.module.context.outer.as_ref().unwrap();
let existing_vi = sig
.ident()
.and_then(|ident| outer.get_current_scope_var(&ident.name))
.cloned();
let existing_t = existing_vi.as_ref().map(|vi| vi.t.clone());
let expect_body_t = sig
.t_spec
.as_ref()
Expand All @@ -2269,14 +2274,14 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
})
.or_else(|| {
sig.ident()
.and_then(|ident| outer.get_current_scope_var(&ident.name))
.and_then(|ident| outer.rec_get_param_or_decl_info(ident.inspect()))
.map(|vi| vi.t.clone())
});
match self.lower_block(body.block, expect_body.or(expect_body_t.as_ref())) {
Ok(block) => {
let found_body_t = block.ref_t();
let ident = match &sig.pat {
ast::VarPattern::Ident(ident) => ident.clone(),
ast::VarPattern::Ident(ident) | ast::VarPattern::Phi(ident) => ident.clone(),
ast::VarPattern::Discard(token) => {
ast::Identifier::private_from_token(token.clone())
}
Expand All @@ -2291,6 +2296,7 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
.map_err(|errs| (None, errors.concat(errs)));
}
};
let mut no_reassign = false;
if let Some(expect_body_t) = expect_body_t {
// TODO: expect_body_t is smaller for constants
// TODO: 定数の場合、expect_body_tのほうが小さくなってしまう
Expand All @@ -2302,20 +2308,35 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
found_body_t,
) {
errors.push(e);
no_reassign = true;
}
}
}
let vi = match self.module.context.outer.as_mut().unwrap().assign_var_sig(
&sig,
found_body_t,
body.id,
block.last(),
None,
) {
Ok(vi) => vi,
Err(errs) => {
errors.extend(errs);
VarInfo::ILLEGAL
let found_body_t = if sig.is_phi() {
self.module
.context
.union(existing_t.as_ref().unwrap_or(&Type::Never), found_body_t)
} else {
found_body_t.clone()
};
let vi = if no_reassign {
VarInfo {
t: found_body_t,
..existing_vi.unwrap_or_default()
}
} else {
match self.module.context.outer.as_mut().unwrap().assign_var_sig(
&sig,
&found_body_t,
body.id,
block.last(),
None,
) {
Ok(vi) => vi,
Err(errs) => {
errors.extend(errs);
VarInfo::ILLEGAL
}
}
};
let ident = hir::Identifier::new(ident, None, vi);
Expand Down Expand Up @@ -2351,7 +2372,7 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
errors.extend(errs);
let found_body_t = block.ref_t();
let ident = match &sig.pat {
ast::VarPattern::Ident(ident) => ident.clone(),
ast::VarPattern::Ident(ident) | ast::VarPattern::Phi(ident) => ident.clone(),
ast::VarPattern::Discard(token) => {
ast::Identifier::private_from_token(token.clone())
}
Expand All @@ -2366,9 +2387,16 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
.map_err(|errs| (None, errors.concat(errs)));
}
};
let found_body_t = if sig.is_phi() {
self.module
.context
.union(existing_t.as_ref().unwrap_or(&Type::Never), found_body_t)
} else {
found_body_t.clone()
};
if let Err(errs) = self.module.context.outer.as_mut().unwrap().assign_var_sig(
&sig,
found_body_t,
&found_body_t,
ast::DefId(0),
None,
None,
Expand Down
17 changes: 13 additions & 4 deletions crates/erg_parser/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4723,6 +4723,10 @@ pub enum VarPattern {
Discard(Token),
Glob(Token),
Ident(Identifier),
/// Used when a different value is assigned in a branch other than `Ident`.
/// (e.g. the else variable when a variable is defined with Python if-else)
/// Not used in Erg mode at this time
Phi(Identifier),
/// e.g. `[x, y, z]` of `[x, y, z] = [1, 2, 3]`
List(VarListPattern),
/// e.g. `(x, y, z)` of `(x, y, z) = (1, 2, 3)`
Expand All @@ -4739,6 +4743,7 @@ impl NestedDisplay for VarPattern {
Self::Discard(_) => write!(f, "_"),
Self::Glob(_) => write!(f, "*"),
Self::Ident(ident) => write!(f, "{ident}"),
Self::Phi(ident) => write!(f, "(phi){ident}"),
Self::List(l) => write!(f, "{l}"),
Self::Tuple(t) => write!(f, "{t}"),
Self::Record(r) => write!(f, "{r}"),
Expand All @@ -4748,9 +4753,9 @@ impl NestedDisplay for VarPattern {
}

impl_display_from_nested!(VarPattern);
impl_locational_for_enum!(VarPattern; Discard, Glob, Ident, List, Tuple, Record, DataPack);
impl_into_py_for_enum!(VarPattern; Discard, Glob, Ident, List, Tuple, Record, DataPack);
impl_from_py_for_enum!(VarPattern; Discard(Token), Glob(Token), Ident(Identifier), List(VarListPattern), Tuple(VarTuplePattern), Record(VarRecordPattern), DataPack(VarDataPackPattern));
impl_locational_for_enum!(VarPattern; Discard, Glob, Ident, Phi, List, Tuple, Record, DataPack);
impl_into_py_for_enum!(VarPattern; Discard, Glob, Ident, Phi, List, Tuple, Record, DataPack);
impl_from_py_for_enum!(VarPattern; Discard(Token), Glob(Token), Ident(Identifier), Phi(Identifier), List(VarListPattern), Tuple(VarTuplePattern), Record(VarRecordPattern), DataPack(VarDataPackPattern));

impl VarPattern {
pub const fn inspect(&self) -> Option<&Str> {
Expand Down Expand Up @@ -4900,10 +4905,14 @@ impl VarSignature {

pub fn ident(&self) -> Option<&Identifier> {
match &self.pat {
VarPattern::Ident(ident) => Some(ident),
VarPattern::Ident(ident) | VarPattern::Phi(ident) => Some(ident),
_ => None,
}
}

pub fn is_phi(&self) -> bool {
matches!(self.pat, VarPattern::Phi(_))
}
}

#[pyclass]
Expand Down
10 changes: 8 additions & 2 deletions crates/erg_parser/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,10 @@ impl Desugarer {
self.desugar_nested_var_pattern(new, rhs, &buf_name, BufIndex::Record(lhs));
}
}
VarPattern::Ident(_) | VarPattern::Discard(_) | VarPattern::Glob(_) => {
VarPattern::Ident(_)
| VarPattern::Phi(_)
| VarPattern::Discard(_)
| VarPattern::Glob(_) => {
if let VarPattern::Ident(ident) = v.pat {
v.pat = VarPattern::Ident(Self::desugar_ident(ident));
}
Expand Down Expand Up @@ -966,7 +969,10 @@ impl Desugarer {
);
}
}
VarPattern::Ident(_) | VarPattern::Discard(_) | VarPattern::Glob(_) => {
VarPattern::Ident(_)
| VarPattern::Phi(_)
| VarPattern::Discard(_)
| VarPattern::Glob(_) => {
let def = Def::new(Signature::Var(sig.clone()), body);
new_module.push(Expr::Def(def));
}
Expand Down

0 comments on commit 01a5938

Please sign in to comment.