diff --git a/.gitignore b/.gitignore
index 4ca31db..ee12072 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,5 @@ beniget.egg-info
.pytest_cache
.vscode
+build
+.tox
\ No newline at end of file
diff --git a/README.rst b/README.rst
index 85746ca..426b431 100644
--- a/README.rst
+++ b/README.rst
@@ -4,8 +4,8 @@ Gast, Beniget!
Beniget is a collection of Compile-time analyse on Python Abstract Syntax Tree(AST).
It's a building block to write static analyzer or compiler for Python.
-Beniget relies on `gast `_ to provide a cross
-version abstraction of the AST, effectively working across all Python 3 versions greater than 3.6.
+Beniget is compatible with the standard library AST as well as the AST generated by `gast `_,
+which is a cross version abstraction of the AST.
API
---
@@ -16,6 +16,12 @@ Basically Beniget provides three analyse:
- ``beniget.DefUseChains`` that maps each node to the list of definition points in that node;
- ``beniget.UseDefChains`` that maps each node to the list of possible definition of that node.
+Alternatives working with standard library AST:
+
+- ``beniget.standard.Ancestors`` that maps each node to the list of enclosing nodes;
+- ``beniget.standard.DefUseChains`` that maps each node to the list of definition points in that node;
+- ``beniget.standard.UseDefChains`` that maps each node to the list of possible definition of that node.
+
See sample usages and/or run ``pydoc beniget`` for more information :-).
diff --git a/beniget/beniget.py b/beniget/beniget.py
index 9f45161..29b1edb 100644
--- a/beniget/beniget.py
+++ b/beniget/beniget.py
@@ -1,30 +1,52 @@
from collections import defaultdict
from contextlib import contextmanager
+import itertools
import sys
+import ast as _ast
import gast as ast
from .ordered_set import ordered_set
+_ClassOrFunction = set(('ClassDef', 'FunctionDef', 'AsyncFunctionDef'))
+_Comp = set(('DictComp', 'ListComp', 'SetComp', 'GeneratorExp'))
+_ClosedScopes = set(('FunctionDef', 'AsyncFunctionDef',
+ 'Lambda', 'DictComp', 'ListComp',
+ 'SetComp', 'GeneratorExp', 'def695'))
+_TypeVarLike = set(('TypeVar', 'TypeVarTuple', 'ParamSpec'))
+_HasName = set((*_ClassOrFunction, *_TypeVarLike))
+
class Ancestors(ast.NodeVisitor):
"""
Build the ancestor tree, that associates a node to the list of node visited
from the root node (the Module) to the current node
-
>>> import gast as ast
>>> code = 'def foo(x): return x + 1'
>>> module = ast.parse(code)
-
>>> from beniget import Ancestors
>>> ancestors = Ancestors()
>>> ancestors.visit(module)
-
>>> binop = module.body[0].body[0].value
>>> for n in ancestors.parents(binop):
... print(type(n))
+
+ Also works with standard library nodes
+
+ >>> import ast as _ast
+ >>> code = 'def foo(x): return x + 1'
+ >>> module = _ast.parse(code)
+ >>> from beniget import Ancestors
+ >>> ancestors = Ancestors()
+ >>> ancestors.visit(module)
+ >>> binop = module.body[0].body[0].value
+ >>> for n in ancestors.parents(binop):
+ ... print(str(type(n)).replace('_ast', 'ast'))
+
+
+
"""
def __init__(self):
@@ -34,7 +56,7 @@ def __init__(self):
def generic_visit(self, node):
self._parents[node] = list(self._current)
self._current.append(node)
- super(Ancestors, self).generic_visit(node)
+ super().generic_visit(node)
self._current.pop()
def parent(self, node):
@@ -51,11 +73,29 @@ def parentInstance(self, node, cls):
def parentFunction(self, node):
return self.parentInstance(node, (ast.FunctionDef,
- ast.AsyncFunctionDef))
+ ast.AsyncFunctionDef,
+ _ast.FunctionDef,
+ _ast.AsyncFunctionDef))
def parentStmt(self, node):
- return self.parentInstance(node, ast.stmt)
+ return self.parentInstance(node, _ast.stmt)
+_novalue = object()
+@contextmanager
+def _setattrs(obj, **attrs):
+ """
+ Provide cheap attribute polymorphism.
+ """
+ old_values = {}
+ for k, v in attrs.items():
+ old_values[k] = getattr(obj, k, _novalue)
+ setattr(obj, k, v)
+ yield
+ for k, v in old_values.items():
+ if v is _novalue:
+ delattr(obj, k)
+ else:
+ setattr(obj, k, v)
class Def(object):
"""
@@ -76,7 +116,7 @@ def __init__(self, node):
"""
def add_user(self, node):
- assert isinstance(node, Def)
+ assert isinstance(node, Def), node
self._users.add(node)
def name(self):
@@ -84,19 +124,25 @@ def name(self):
If the node associated to this Def has a name, returns this name.
Otherwise returns its type
"""
- if isinstance(self.node, (ast.ClassDef,
- ast.FunctionDef,
- ast.AsyncFunctionDef)):
+ typename = type(self.node).__name__
+ if typename in _HasName:
return self.node.name
- elif isinstance(self.node, ast.Name):
+ elif typename == 'Name':
return self.node.id
- elif isinstance(self.node, ast.alias):
+ elif typename == 'alias':
base = self.node.name.split(".", 1)[0]
return self.node.asname or base
+ elif typename in ('MatchStar', 'MatchAs') and self.node.name:
+ return self.node.name
+ elif typename == 'MatchMapping' and self.node.rest:
+ return self.node.rest
+ elif typename == 'arg':
+ return self.node.arg
+ elif typename == 'ExceptHandler' and self.node.name:
+ return self.node.name
elif isinstance(self.node, tuple):
return self.node[1]
- else:
- return type(self.node).__name__
+ return typename
def users(self):
"""
@@ -144,7 +190,7 @@ def collect_future_imports(node):
"""
Returns a set of future imports names for the given ast module.
"""
- assert isinstance(node, ast.Module)
+ assert type(node).__name__ == 'Module'
cf = _CollectFutureImports()
cf.visit(node)
return cf.FutureImports
@@ -185,6 +231,9 @@ def visit_Constant(self, node):
def generic_visit(self, node):
raise _StopTraversal()
+ def visit_Str(self, node):
+ pass
+
class CollectLocals(ast.NodeVisitor):
def __init__(self):
self.Locals = set()
@@ -204,7 +253,7 @@ def visit_Nonlocal(self, node):
visit_Global = visit_Nonlocal
def visit_Name(self, node):
- if isinstance(node.ctx, ast.Store) and node.id not in self.NonLocals:
+ if type(node.ctx).__name__ == 'Store' and node.id not in self.NonLocals:
self.Locals.add(node.id)
def skip(self, _):
@@ -224,16 +273,32 @@ def visit_ImportFrom(self, node):
for alias in node.names:
self.Locals.add(alias.asname or alias.name)
+class CollectLocalsdef695(CollectLocals):
+
+ visit_TypeVar = visit_ParamSpec = visit_TypeVarTuple = CollectLocals.visit_FunctionDef
+
def collect_locals(node):
'''
Compute the set of identifiers local to a given node.
This is meant to emulate a call to locals()
'''
- visitor = CollectLocals()
+ if isinstance(node, def695):
+ # workaround for the new implicit scope created by type params and co.
+ visitor = CollectLocalsdef695()
+ else:
+ visitor = CollectLocals()
visitor.generic_visit(node)
return visitor.Locals
+class def695(ast.stmt):
+ """
+ Special statement to represent the PEP-695 lexical scopes.
+ """
+ _fields = ('body', 'd')
+ def __init__(self, body, d):
+ self.body = body # list of type params
+ self.d = d # the wrapped definition node
class DefUseChains(ast.NodeVisitor):
"""
@@ -289,7 +354,9 @@ def __init__(self, filename=None):
# be defined in another path of the control flow (esp. in loop)
self._undefs = []
- # stack of nodes starting a scope: class, module, function, generator expression, comprehension...
+ # stack of nodes starting a scope:
+ # class, module, function, generator expression,
+ # comprehension, def695.
self._scopes = []
self._breaks = []
@@ -321,11 +388,11 @@ def _dump_locals(self, node, only_live=False):
for d in self.locals[node]:
if not only_live or d.islive:
groupped[d.name()].append(d)
- return ['{}:{}'.format(name, ','.join([str(getattr(d.node, 'lineno', -1)) for d in defs])) \
+ return ['{}:{}'.format(name, ','.join([str(getattr(d.node, 'lineno', None)) for d in defs])) \
for name,defs in groupped.items()]
def dump_definitions(self, node, ignore_builtins=True):
- if isinstance(node, ast.Module) and not ignore_builtins:
+ if type(node).__name__ == 'Module' and not ignore_builtins:
builtins = {d for d in self._builtins.values()}
return sorted(d.name()
for d in self.locals[node] if d not in builtins)
@@ -345,7 +412,7 @@ def location(self, node):
)
return " at {}{}:{}".format(filename,
node.lineno,
- node.col_offset)
+ getattr(node, 'col_offset', None),)
else:
return ""
@@ -375,7 +442,7 @@ def invalid_name_lookup(self, name, scope, precomputed_locals, local_defs):
# >>> foo() # fails, a is a local referenced before being assigned
# >>> class bar: a = a
# >>> bar() # ok, and `bar.a is a`
- if isinstance(scope, ast.ClassDef):
+ if type(scope).__name__ in ('ClassDef', 'def695'): # TODO: test the def695 part of this
top_level_definitions = self._definitions[0:-self._scope_depths[0]]
isglobal = any((name in top_lvl_def or '*' in top_lvl_def)
for top_lvl_def in top_level_definitions)
@@ -421,6 +488,7 @@ def compute_defs(self, node, quiet=False):
precomputed_locals = next(precomputed_locals_iter)
base_scope = next(scopes_iter)
defs = self._definitions[depth:]
+ is_def695 = isinstance(base_scope, def695)
if not self.invalid_name_lookup(name, base_scope, precomputed_locals, defs):
looked_up_definitions.extend(reversed(defs))
@@ -428,7 +496,9 @@ def compute_defs(self, node, quiet=False):
for scope, depth, precomputed_locals in zip(scopes_iter,
depths_iter,
precomputed_locals_iter):
- if not isinstance(scope, ast.ClassDef):
+ # If a def695 scope is immediately within a class scope, or within another def695 scope that is immediately within a class scope,
+ # then names defined in that class scope can be accessed within the def695 scope.
+ if type(scope).__name__ != 'ClassDef' or is_def695:
defs = self._definitions[lvl + depth: lvl]
if self.invalid_name_lookup(name, base_scope, precomputed_locals, defs):
looked_up_definitions.append(StopIteration)
@@ -462,7 +532,7 @@ def process_body(self, stmts):
deadcode = False
for stmt in stmts:
self.visit(stmt)
- if isinstance(stmt, (ast.Break, ast.Continue, ast.Raise)):
+ if type(stmt).__name__ in ('Break', 'Continue', 'Raise'):
if not deadcode:
deadcode = True
self._deadcode += 1
@@ -595,7 +665,7 @@ def set_definition(self, name, dnode_or_dnodes, index=-1):
# set the islive flag to False on killed Defs
for d in self._definitions[index].get(name, ()):
- if not isinstance(d.node, ast.AST):
+ if not isinstance(d.node, _ast.AST):
# A builtin: we never explicitely mark the builtins as killed, since
# it can be easily deducted.
continue
@@ -650,48 +720,91 @@ def visit_annotation(self, node):
self.visit(annotation)
def visit_skip_annotation(self, node):
- if isinstance(node, ast.Name):
+ if type(node).__name__ == 'Name':
self.visit_Name(node, skip_annotation=True)
else:
self.visit(node)
- def visit_FunctionDef(self, node, step=DeclarationStep):
+ def visit_FunctionDef(self, node, step=DeclarationStep, in_def695=False):
if step is DeclarationStep:
dnode = self.chains.setdefault(node, Def(node))
self.add_to_locals(node.name, dnode)
+ if not in_def695:
+
+ for kw_default in filter(None, node.args.kw_defaults):
+ self.visit(kw_default).add_user(dnode)
+ for default in node.args.defaults:
+ self.visit(default).add_user(dnode)
+ for decorator in node.decorator_list:
+ self.visit(decorator)
+
+ if any(getattr(node, 'type_params', [])):
+ self.visit_def695(def695(body=node.type_params, d=node))
+ return
+
if not self.future_annotations:
for arg in _iter_arguments(node.args):
- self.visit_annotation(arg)
+ annotation = getattr(arg, 'annotation', None)
+ if annotation:
+ if in_def695:
+ try:
+ _validate_annotation_body(annotation)
+ except SyntaxError as e :
+ self.warn(str(e), annotation)
+ continue
+ self.visit(annotation)
else:
# annotations are to be analyzed later as well
currentscopes = list(self._scopes)
if node.returns:
- self._defered_annotations[-1].append(
- (node.returns, currentscopes, None))
+ try:
+ _validate_annotation_body(node.returns)
+ except SyntaxError as e :
+ self.warn(str(e), node.returns)
+ else:
+ self._defered_annotations[-1].append(
+ (node.returns, currentscopes, None))
for arg in _iter_arguments(node.args):
if arg.annotation:
+ try:
+ _validate_annotation_body(arg.annotation)
+ except SyntaxError as e :
+ self.warn(str(e), arg.annotation)
+ continue
self._defered_annotations[-1].append(
(arg.annotation, currentscopes, None))
- for kw_default in filter(None, node.args.kw_defaults):
- self.visit(kw_default).add_user(dnode)
- for default in node.args.defaults:
- self.visit(default).add_user(dnode)
- for decorator in node.decorator_list:
- self.visit(decorator)
-
if not self.future_annotations and node.returns:
- self.visit(node.returns)
-
- self.set_definition(node.name, dnode)
+ if in_def695:
+ try:
+ _validate_annotation_body(node.returns)
+ except SyntaxError as e:
+ self.warn(str(e), node.returns)
+ else:
+ self.visit(node.returns)
+ else:
+ self.visit(node.returns)
+
+ if in_def695:
+ # emulate this (except f is not actually defined in both scopes):
+ # def695 __generic_parameters_of_f():
+ # T = TypeVar(name='T')
+ # def f(x: T) -> T:
+ # return x
+ # return f
+ # f = __generic_parameters_of_f()
+ self.set_definition(node.name, dnode, index=-2)
+ else:
+ self.set_definition(node.name, dnode)
self._defered.append((node,
list(self._definitions),
list(self._scopes),
list(self._scope_depths),
list(self._precomputed_locals)))
+
elif step is DefinitionStep:
with self.ScopeContext(node):
for arg in _iter_arguments(node.args):
@@ -702,23 +815,44 @@ def visit_FunctionDef(self, node, step=DeclarationStep):
visit_AsyncFunctionDef = visit_FunctionDef
- def visit_ClassDef(self, node):
+ def visit_ClassDef(self, node, in_def695=False):
dnode = self.chains.setdefault(node, Def(node))
self.add_to_locals(node.name, dnode)
+
+ if not in_def695:
+ for decorator in node.decorator_list:
+ self.visit(decorator).add_user(dnode)
+
+ if any(getattr(node, 'type_params', [])):
+ self.visit_def695(def695(body=node.type_params, d=node))
+ return
for base in node.bases:
+ if in_def695:
+ try:
+ _validate_annotation_body(base)
+ except SyntaxError as e:
+ self.warn(str(e), base)
+ continue
self.visit(base).add_user(dnode)
for keyword in node.keywords:
+ if in_def695:
+ try:
+ _validate_annotation_body(keyword)
+ except SyntaxError as e:
+ self.warn(str(e), keyword)
+ continue
self.visit(keyword.value).add_user(dnode)
- for decorator in node.decorator_list:
- self.visit(decorator).add_user(dnode)
with self.ScopeContext(node):
self.set_definition("__class__", Def("__class__"))
self.process_body(node.body)
- self.set_definition(node.name, dnode)
-
+ if in_def695:
+ # see comment in visit_FunctionDef
+ self.set_definition(node.name, dnode, index=-2)
+ else:
+ self.set_definition(node.name, dnode)
def visit_Return(self, node):
if node.value:
@@ -750,13 +884,18 @@ def visit_AnnAssign(self, node):
if not self.future_annotations:
self.visit(node.annotation)
else:
- self._defered_annotations[-1].append(
- (node.annotation, list(self._scopes), None))
+ try:
+ _validate_annotation_body(node.annotation)
+ except SyntaxError as e:
+ self.warn(str(e), node.annotation)
+ else:
+ self._defered_annotations[-1].append(
+ (node.annotation, list(self._scopes), None))
self.visit(node.target)
def visit_AugAssign(self, node):
dvalue = self.visit(node.value)
- if isinstance(node.target, ast.Name):
+ if type(node.target).__name__ == 'Name':
ctx, node.target.ctx = node.target.ctx, ast.Load()
dtarget = self.visit(node.target)
dvalue.add_user(dtarget)
@@ -773,6 +912,47 @@ def visit_AugAssign(self, node):
self.locals[self._scopes[-1]].append(dtarget)
else:
self.visit(node.target).add_user(dvalue)
+
+ def visit_TypeAlias(self, node, in_def695=False):
+ # Generic type aliases:
+ # type Alias[T: int] = list[T]
+
+ # Equivalent to:
+ # def695 __generic_parameters_of_Alias():
+ # def695 __evaluate_T_bound():
+ # return int
+ # T = __make_typevar_with_bound(name='T', evaluate_bound=__evaluate_T_bound)
+ # def695 __evaluate_Alias():
+ # return list[T]
+ # return __make_typealias(name='Alias', type_params=(T,), evaluate_value=__evaluate_Alias)
+ # Alias = __generic_parameters_of_Alias()
+
+ if type(node.name).__name__ == 'Name':
+ dname = self.chains.setdefault(node.name, Def(node.name))
+ self.add_to_locals(node.name.id, dname)
+
+ if not in_def695 and any(getattr(node, 'type_params', [])):
+ self.visit_def695(def695(body=node.type_params, d=node))
+ return
+
+ dnode = self.chains.setdefault(node, Def(node))
+ try:
+ _validate_annotation_body(node.value)
+ except SyntaxError as e:
+ self.warn(str(e), node.value)
+ else:
+ self._defered_annotations[-1].append(
+ (node.value, list(self._scopes), None))
+
+ if in_def695:
+ # see comment in visit_FunctionDef
+ self.set_definition(node.name.id, dname, index=-2)
+ else:
+ self.set_definition(node.name.id, dname)
+
+ return dnode
+ else:
+ raise NotImplementedError()
def visit_For(self, node):
self.visit(node.iter)
@@ -912,17 +1092,17 @@ def visit_Try(self, node):
self.extend_definition(hd, handler_def[hd])
self.process_body(node.finalbody)
-
+
def visit_Assert(self, node):
self.visit(node.test)
if node.msg:
self.visit(node.msg)
- def add_to_locals(self, name, dnode):
+ def add_to_locals(self, name, dnode, index=-1):
if any(name in _globals for _globals in self._globals):
self.set_or_extend_global(name, dnode)
- else:
- self.locals[self._scopes[-1]].append(dnode)
+ elif dnode not in self.locals[self._scopes[index]]:
+ self.locals[self._scopes[index]].append(dnode)
def visit_Import(self, node):
for alias in node.names:
@@ -946,10 +1126,16 @@ def visit_Global(self, node):
def visit_Nonlocal(self, node):
for name in node.names:
- for d in reversed(self._definitions[:-1]):
+ for i, d in enumerate(reversed(self._definitions)):
+ if i == 0:
+ continue
if name not in d:
continue
else:
+ if isinstance(self._scopes[-i-1], def695):
+ # see https://docs.python.org/3.12/reference/executionmodel.html#annotation-scopes
+ self.warn("names defined in annotation scopes cannot be rebound with nonlocal statements", node)
+ break
# this rightfully creates aliasing
self.set_definition(name, d[name])
break
@@ -959,7 +1145,100 @@ def visit_Nonlocal(self, node):
def visit_Expr(self, node):
self.generic_visit(node)
- # expr
+ # pattern matching
+
+ def visit_Match(self, node):
+
+ self.visit(node.subject)
+
+ defs = []
+ for kase in node.cases:
+ if kase.guard:
+ self.visit(kase.guard)
+ self.visit(kase.pattern)
+
+ with self.DefinitionContext(self._definitions[-1].copy()) as case_defs:
+ self.process_body(kase.body)
+ defs.append(case_defs)
+
+ if not defs:
+ return
+ if len(defs) == 1:
+ body_defs, orelse_defs, rest = defs[0], [], []
+ else:
+ body_defs, orelse_defs, rest = defs[0], defs[1], defs[2:]
+ while True:
+ # merge defs, like in if-else but repeat the process for x branches
+ for d in body_defs:
+ if d in orelse_defs:
+ self.set_definition(d, body_defs[d] + orelse_defs[d])
+ else:
+ self.extend_definition(d, body_defs[d])
+ for d in orelse_defs:
+ if d not in body_defs:
+ self.extend_definition(d, orelse_defs[d])
+ if not rest:
+ break
+ body_defs = self._definitions[-1]
+ orelse_defs, rest = rest[0], rest[1:]
+
+ def visit_MatchValue(self, node):
+ dnode = self.chains.setdefault(node, Def(node))
+ self.visit(node.value)
+ return dnode
+
+ visit_MatchSingleton = visit_MatchValue
+
+ def visit_MatchSequence(self, node):
+ # mimics a list
+ with _setattrs(node, ctx=ast.Load(), elts=node.patterns):
+ return self.visit_List(node)
+
+ def visit_MatchMapping(self, node):
+ dnode = self.chains.setdefault(node, Def(node))
+ with _setattrs(node, values=node.patterns):
+ # mimics a dict
+ self.visit_Dict(node)
+ if node.rest:
+ with _setattrs(node, id=node.rest, ctx=ast.Store(), annotation=None):
+ self.visit_Name(node)
+ return dnode
+
+ def visit_MatchClass(self, node):
+ # mimics a call
+ dnode = self.chains.setdefault(node, Def(node))
+ self.visit(node.cls).add_user(dnode)
+ for arg in node.patterns:
+ self.visit(arg).add_user(dnode)
+ for kw in node.kwd_patterns:
+ self.visit(kw).add_user(dnode)
+ return dnode
+
+ def visit_MatchStar(self, node):
+ dnode = self.chains.setdefault(node, Def(node))
+ if node.name:
+ # mimics store name
+ with _setattrs(node, id=node.name, ctx=ast.Store(), annotation=None):
+ self.visit_Name(node)
+ return dnode
+
+ def visit_MatchAs(self, node):
+ dnode = self.chains.setdefault(node, Def(node))
+ if node.pattern:
+ self.visit(node.pattern)
+ if node.name:
+ with _setattrs(node, id=node.name, ctx=ast.Store(), annotation=None):
+ self.visit_Name(node)
+ return dnode
+
+ def visit_MatchOr(self, node):
+ dnode = self.chains.setdefault(node, Def(node))
+ for pat in node.patterns:
+ self.visit(pat).add_user(dnode)
+ return dnode
+
+ # expressions
+
def visit_BoolOp(self, node):
dnode = self.chains.setdefault(node, Def(node))
for value in node.values:
@@ -1112,7 +1391,7 @@ def visit_Subscript(self, node):
return dnode
def visit_Starred(self, node):
- if isinstance(node.ctx, ast.Store):
+ if type(node.ctx).__name__ == 'Store':
return self.visit(node.value)
else:
dnode = self.chains.setdefault(node, Def(node))
@@ -1122,25 +1401,21 @@ def visit_Starred(self, node):
def visit_NamedExpr(self, node):
dnode = self.chains.setdefault(node, Def(node))
self.visit(node.value).add_user(dnode)
- if isinstance(node.target, ast.Name):
+ if type(node.target).__name__ == 'Name':
self.visit_Name(node.target, named_expr=True)
return dnode
- def is_in_current_scope(self, name):
- return any(name in defs
- for defs in self._definitions[self._scope_depths[-1]:])
-
def _first_non_comprehension_scope(self):
index = -1
enclosing_scope = self._scopes[index]
- while isinstance(enclosing_scope, (ast.DictComp, ast.ListComp,
- ast.SetComp, ast.GeneratorExp)):
+ while type(enclosing_scope).__name__ in _Comp:
index -= 1
enclosing_scope = self._scopes[index]
return index, enclosing_scope
def visit_Name(self, node, skip_annotation=False, named_expr=False):
- if isinstance(node.ctx, (ast.Param, ast.Store)):
+ ctx_typename = type(node.ctx).__name__
+ if ctx_typename in ('Param', 'Store'):
dnode = self.chains.setdefault(node, Def(node))
# FIXME: find a smart way to merge the code below with add_to_locals
if any(node.id in _globals for _globals in self._globals):
@@ -1151,7 +1426,7 @@ def visit_Name(self, node, skip_annotation=False, named_expr=False):
index, enclosing_scope = (self._first_non_comprehension_scope()
if named_expr else (-1, self._scopes[-1]))
- if index < -1 and isinstance(enclosing_scope, ast.ClassDef):
+ if index < -1 and type(enclosing_scope).__name__ == 'ClassDef':
# invalid named expression, not calling set_definition.
self.warn('assignment expression within a comprehension '
'cannot be used in a class body', node)
@@ -1162,11 +1437,10 @@ def visit_Name(self, node, skip_annotation=False, named_expr=False):
self.locals[self._scopes[index]].append(dnode)
# Name.annotation is a special case because of gast
- if node.annotation is not None and not skip_annotation and not self.future_annotations:
+ if getattr(node, 'annotation', None) is not None and not skip_annotation and not self.future_annotations:
self.visit(node.annotation)
-
- elif isinstance(node.ctx, (ast.Load, ast.Del)):
+ elif ctx_typename in ('Load', 'Del'):
node_in_chains = node in self.chains
if node_in_chains:
dnode = self.chains[node]
@@ -1185,26 +1459,29 @@ def visit_Destructured(self, node):
dnode = self.chains.setdefault(node, Def(node))
tmp_store = ast.Store()
for elt in node.elts:
- if isinstance(elt, ast.Name):
+ elt_typename = type(elt).__name__
+ if elt_typename == 'Name':
tmp_store, elt.ctx = elt.ctx, tmp_store
self.visit(elt)
tmp_store, elt.ctx = elt.ctx, tmp_store
- elif isinstance(elt, (ast.Subscript, ast.Starred, ast.Attribute)):
+ elif elt_typename in ('Subscript', 'Starred', 'Attribute'):
self.visit(elt)
- elif isinstance(elt, (ast.List, ast.Tuple)):
+ elif elt_typename in ('List', 'Tuple'):
self.visit_Destructured(elt)
return dnode
def visit_List(self, node):
- if isinstance(node.ctx, ast.Load):
+ if type(node.ctx).__name__ == 'Load':
dnode = self.chains.setdefault(node, Def(node))
for elt in node.elts:
self.visit(elt).add_user(dnode)
return dnode
# unfortunately, destructured node are marked as Load,
# only the parent List/Tuple is marked as Store
- elif isinstance(node.ctx, ast.Store):
+ elif type(node.ctx).__name__ == 'Store':
return self.visit_Destructured(node)
+ else:
+ raise NotImplementedError()
visit_Tuple = visit_List
@@ -1219,6 +1496,49 @@ def visit_Slice(self, node):
if node.step:
self.visit(node.step).add_user(dnode)
return dnode
+
+ # type params
+
+ def visit_def695(self, node):
+ # We don't use two steps here because the declaration
+ # step is the same as definition step for def695's
+ # 1.type parameters of generic type aliases,
+ # 2.type parameters and annotations of generic functions and
+ # 3.type parameters and base class expressions of generic classes
+ # the rest is evaluated as defered annotations:
+ # 4.the value of generic type aliases
+ # 5.the bounds of type variables
+ # 6.the constraints of type variables
+
+ # introduce the new scope
+ dnode = self.chains.setdefault(node.d, Def(node.d))
+
+ with self.ScopeContext(node):
+ # visit the type params
+ for p in node.body:
+ try:
+ _validate_annotation_body(p)
+ except SyntaxError as e:
+ self.warn(str(e), p)
+ else:
+ self.visit(p).add_user(dnode)
+ # then visit the actual node while
+ # being in the def695 scope.
+ visitor = getattr(self, "visit_{}".format(type(node.d).__name__))
+ visitor(node.d, in_def695=True)
+
+ def visit_TypeVar(self, node):
+ dnode = self.chains.setdefault(node, Def(node))
+ self.set_definition(node.name, dnode)
+ self.add_to_locals(node.name, dnode)
+
+ if type(node).__name__ == 'TypeVar' and node.bound:
+ self._defered_annotations[-1].append(
+ (node.bound, list(self._scopes), None))
+
+ return dnode
+
+ visit_ParamSpec = visit_TypeVarTuple = visit_TypeVar
# misc
@@ -1248,9 +1568,7 @@ def visit_excepthandler(self, node):
self.process_body(node.body)
return dnode
- def visit_arguments(self, node):
- for arg in _iter_arguments(node):
- self.visit(arg)
+ # visit_arguments is not implemented on purpose
def visit_withitem(self, node):
dnode = self.chains.setdefault(node, Def(node))
@@ -1267,24 +1585,43 @@ def _validate_comprehension(node):
"""
iter_names = set() # comprehension iteration variables
for gen in node.generators:
- for namedexpr in (n for n in ast.walk(gen.iter) if isinstance(n, ast.NamedExpr)):
+ for namedexpr in (n for n in ast.walk(gen.iter) if type(n).__name__ == 'NamedExpr'):
raise SyntaxError('assignment expression cannot be used '
'in a comprehension iterable expression')
iter_names.update(n.id for n in ast.walk(gen.target)
- if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Store))
- for namedexpr in (n for n in ast.walk(node) if isinstance(n, ast.NamedExpr)):
+ if type(n).__name__ == 'Name' and type(n.ctx).__name__ == 'Store')
+ for namedexpr in (n for n in ast.walk(node) if type(n).__name__ == 'NamedExpr'):
bound = getattr(namedexpr.target, 'id', None)
if bound in iter_names:
raise SyntaxError('assignment expression cannot rebind '
"comprehension iteration variable '{}'".format(bound))
+_node_type_to_human_name = {
+ 'NamedExpr': 'assignment expression',
+ 'Yield': 'yield keyword',
+ 'YieldFrom': 'yield keyword',
+ 'Await': 'await keyword'
+}
+
+def _validate_annotation_body(node):
+ """
+ Raises SyntaxError if:
+ - the warlus operator is used
+ - the yield/ yield from statement is used
+ - the await keyword is used
+ """
+ for illegal in (n for n in ast.walk(node) if type(n).__name__ in
+ ('NamedExpr', 'Yield', 'YieldFrom', 'Await')):
+ name = _node_type_to_human_name.get(type(illegal).__name__, 'current syntax')
+ raise SyntaxError(f'{name} cannot be used in annotation-like scopes')
+
def _iter_arguments(args):
"""
Yields all arguments of the given ast.arguments instance.
"""
for arg in args.args:
yield arg
- for arg in args.posonlyargs:
+ for arg in getattr(args, 'posonlyargs', ()):
yield arg
if args.vararg:
yield args.vararg
@@ -1345,7 +1682,15 @@ def lookup_annotation_name_defs(name, heads, locals_map):
try:
return _lookup(name, scopes, locals_map)
except LookupError:
- raise LookupError("'{}' not found in {}, might be a builtin".format(name, heads[-1]))
+ if name in BuiltinsSrc:
+ raise LookupError(f'{name} is a builtin')
+ try:
+ _lookup(name, scopes, locals_map, only_live=False)
+ except LookupError:
+ defined_names = [d.name() for s in scopes for d in locals_map[s]]
+ raise LookupError("'{}' not found in scopes: {} (heads={}) (available names={})".format(name, scopes, heads, defined_names))
+ else:
+ raise LookupError("'{}' is killed".format(name))
def _get_lookup_scopes(heads):
# heads[-1] is the direct enclosing scope and heads[0] is the module.
@@ -1355,33 +1700,33 @@ def _get_lookup_scopes(heads):
heads = list(heads) # avoid modifying the list (important)
try:
- direct_scope = heads.pop(-1) # this scope is the only one that can be a class
+ direct_scopes = [heads.pop(-1)] # this scope is the only one that can be a class, expect in case of the presence of def695
except IndexError:
raise ValueError('invalid heads: must include at least one element')
try:
global_scope = heads.pop(0)
except IndexError:
# we got only a global scope
- return [direct_scope]
+ return direct_scopes
+ else:
+ if heads and isinstance(direct_scopes[-1], def695) and type(heads[-1]).__name__ == 'ClassDef':
+ direct_scopes.insert(0, heads.pop(-1))
# more of less modeling what's described here.
# https://github.com/gvanrossum/gvanrossum.github.io/blob/main/formal/scopesblog.md
- other_scopes = [s for s in heads if isinstance(s, (
- ast.FunctionDef, ast.AsyncFunctionDef,
- ast.Lambda, ast.DictComp, ast.ListComp,
- ast.SetComp, ast.GeneratorExp))]
- return [global_scope] + other_scopes + [direct_scope]
-
-def _lookup(name, scopes, locals_map):
- context = scopes.pop()
+ other_scopes = [s for s in heads if type(s).__name__ in _ClosedScopes]
+ return [global_scope] + other_scopes + direct_scopes
+
+def _lookup(name, scopes, locals_map, only_live=True):
+ context = scopes[-1]
defs = []
for loc in locals_map.get(context, ()):
- if loc.name() == name and loc.islive:
+ if loc.name() == name and (loc.islive if only_live else True):
defs.append(loc)
if defs:
return defs
- elif len(scopes)==0:
+ elif len(scopes)==1:
raise LookupError()
- return _lookup(name, scopes, locals_map)
+ return _lookup(name, scopes[:-1], locals_map)
class UseDefChains(object):
"""
@@ -1394,7 +1739,7 @@ class UseDefChains(object):
def __init__(self, defuses):
self.chains = {}
for chain in defuses.chains.values():
- if isinstance(chain.node, ast.Name):
+ if type(chain.node).__name__ == 'Name':
self.chains.setdefault(chain.node, [])
for use in chain.users():
self.chains.setdefault(use.node, []).append(chain)
diff --git a/beniget/standard.py b/beniget/standard.py
new file mode 100644
index 0000000..81b480b
--- /dev/null
+++ b/beniget/standard.py
@@ -0,0 +1,61 @@
+"""
+This module offers the same three analyses, but designed to be run on standard library nodes.
+"""
+
+import ast
+from beniget import beniget, Ancestors
+
+__all__ = ('Ancestors', 'Def', 'DefUseChains', 'UseDefChains')
+
+class DefUseChains(beniget.DefUseChains):
+
+ def visit_skip_annotation(self, node):
+ if isinstance(node, ast.arg):
+ return self.visit_arg(node, skip_annotation=True)
+ return super().visit_skip_annotation(node)
+
+ def visit_ExceptHandler(self, node):
+ if isinstance(node.name, str):
+ # standard library nodes does not wrap
+ # the exception 'as' name in Name instance, so we use
+ # the ExceptHandler instance as reference point.
+ dnode = self.chains.setdefault(node, beniget.Def(node))
+ self.set_definition(node.name, dnode)
+ if dnode not in self.locals[self._scopes[-1]]:
+ self.locals[self._scopes[-1]].append(dnode)
+ self.generic_visit(node)
+
+ def visit_arg(self, node, skip_annotation=False):
+ dnode = self.chains.setdefault(node, beniget.Def(node))
+ self.set_definition(node.arg, dnode)
+ if dnode not in self.locals[self._scopes[-1]]:
+ self.locals[self._scopes[-1]].append(dnode)
+ if node.annotation is not None and not skip_annotation:
+ self.visit(node.annotation)
+ return dnode
+
+ def visit_ExtSlice(self, node):
+ dnode = self.chains.setdefault(node, beniget.Def(node))
+ for elt in node.dims:
+ self.visit(elt).add_user(dnode)
+ return dnode
+
+ def visit_Index(self, node):
+ # pretend Index does not exist
+ return self.visit(node.value)
+
+ visit_NameConstant = visit_Num = visit_Str = \
+ visit_Bytes = visit_Ellipsis = visit_Constant = beniget.DefUseChains.visit_Constant
+
+class UseDefChains(beniget.UseDefChains):
+ def __init__(self, defuses):
+ self.chains = {}
+ for chain in defuses.chains.values():
+ if isinstance(chain.node, (ast.Name, ast.arg)): # the only change is here
+ self.chains.setdefault(chain.node, [])
+ for use in chain.users():
+ self.chains.setdefault(use.node, []).append(chain)
+
+ for chain in defuses._builtins.values():
+ for use in chain.users():
+ self.chains.setdefault(use.node, []).append(chain)
diff --git a/requirements.txt b/requirements.txt
index 6c43616..d02c8e1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1 @@
-gast ~= 0.5.0
+gast ~= 0.5.0
\ No newline at end of file
diff --git a/tests/__init__.py b/tests/__init__.py
index f1b4217..a0cf49a 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,3 +1,6 @@
+"""
+Each TestCase subclass should have it's standard library counterpart.
+"""
import tests.test_definitions
import tests.test_chains
import tests.test_capture
diff --git a/tests/test_attributes.py b/tests/test_attributes.py
index 3aa7eb3..be079f5 100644
--- a/tests/test_attributes.py
+++ b/tests/test_attributes.py
@@ -1,19 +1,20 @@
from unittest import TestCase
from textwrap import dedent
-import gast as ast
-import beniget
+import ast as _ast
+import gast as _gast
+from .test_chains import getDefUseChainsType
-class Attributes(ast.NodeVisitor):
+class Attributes(_ast.NodeVisitor):
def __init__(self, module_node):
- self.chains = beniget.DefUseChains()
+ self.chains = getDefUseChainsType(module_node)()
self.chains.visit(module_node)
self.attributes = set()
self.users = set()
def visit_ClassDef(self, node):
for stmt in node.body:
- if isinstance(stmt, ast.FunctionDef):
+ if isinstance(stmt, (_ast.FunctionDef, _gast.FunctionDef)):
self_def = self.chains.chains[stmt.args.args[0]]
self.users.update(use.node for use in self_def.users())
self.generic_visit(node)
@@ -24,8 +25,9 @@ def visit_Attribute(self, node):
class TestAttributes(TestCase):
+ ast = _gast
def checkAttribute(self, code, extract, ref):
- module = ast.parse(dedent(code))
+ module = self.ast.parse(dedent(code))
c = Attributes(module)
c.visit(extract(module))
self.assertEqual(c.attributes, ref)
@@ -97,3 +99,6 @@ def bar(self, other):
self = list
return self.pop"""
self.checkAttribute(code, lambda n: n.body[0], set())
+
+class TestAttributesStdlib(TestAttributes):
+ ast = _ast
\ No newline at end of file
diff --git a/tests/test_capture.py b/tests/test_capture.py
index ccd472a..c676a90 100644
--- a/tests/test_capture.py
+++ b/tests/test_capture.py
@@ -1,12 +1,14 @@
from unittest import TestCase
from textwrap import dedent
-import gast as ast
-import beniget
+import ast as _ast
+import gast as _gast
+from .test_chains import getDefUseChainsType
-class Capture(ast.NodeVisitor):
+
+class Capture(_ast.NodeVisitor):
def __init__(self, module_node):
- self.chains = beniget.DefUseChains()
+ self.chains = getDefUseChainsType(module_node)()
self.chains.visit(module_node)
self.users = set()
self.captured = set()
@@ -17,15 +19,16 @@ def visit_FunctionDef(self, node):
self.generic_visit(node)
def visit_Name(self, node):
- if isinstance(node.ctx, ast.Load):
+ if isinstance(node.ctx, (_ast.Load, _gast.Load)):
if node not in self.users:
# FIXME: IRL, should be the definition of this use
self.captured.add(node.id)
class TestCapture(TestCase):
+ ast = _gast
def checkCapture(self, code, extract, ref):
- module = ast.parse(dedent(code))
+ module = self.ast.parse(dedent(code))
c = Capture(module)
c.visit(extract(module))
self.assertEqual(c.captured, ref)
@@ -43,3 +46,6 @@ def foo(x):
def bar(x):
return x"""
self.checkCapture(code, lambda n: n.body[0].body[0], set())
+
+class TestCaptureStdlib(TestCapture):
+ ast = _ast
\ No newline at end of file
diff --git a/tests/test_chains.py b/tests/test_chains.py
index a388867..08633a7 100644
--- a/tests/test_chains.py
+++ b/tests/test_chains.py
@@ -1,14 +1,27 @@
from contextlib import contextmanager
from unittest import TestCase, skipIf
import unittest
-import gast as ast
-import beniget
+import beniget.standard
import io
import sys
+import ast as _ast
+import gast as _gast
+import gast.gast as _gast_module
# Show full diff in unittest
unittest.util._MAX_LENGTH=2000
+def replace_deprecated_names(out):
+ return out.replace(
+ 'Num', 'Constant'
+ ).replace(
+ 'Ellipsis', 'Constant'
+ ).replace(
+ 'Str', 'Constant'
+ ).replace(
+ 'Bytes', 'Constant'
+ )
+
@contextmanager
def captured_output():
new_out, new_err = io.StringIO(), io.StringIO()
@@ -19,25 +32,53 @@ def captured_output():
finally:
sys.stdout, sys.stderr = old_out, old_err
+gast_nodes = tuple(getattr(_gast, t[0]) for t in _gast_module._nodes)
-class TestDefUseChains(TestCase):
- def checkChains(self, code, ref, strict=True):
- class StrictDefUseChains(beniget.DefUseChains):
+def getDefUseChainsType(node):
+ if isinstance(node, gast_nodes):
+ return beniget.DefUseChains
+ return beniget.standard.DefUseChains
+
+def getStrictDefUseChains(node):
+ class StrictDefUseChains(getDefUseChainsType(node)):
def warn(self, msg, node):
raise RuntimeError(
"W: {} at {}:{}".format(
msg, node.lineno, node.col_offset
)
)
+ return StrictDefUseChains
+
+def getUseDefChainsType(node):
+ if isinstance(node, gast_nodes):
+ return beniget.UseDefChains
+ return beniget.standard.UseDefChains
- node = ast.parse(code)
+class TestDefUseChains(TestCase):
+ ast = _gast
+ maxDiff = None
+ def checkChains(self, code, ref, strict=True):
+ node = self.ast.parse(code)
if strict:
- c = StrictDefUseChains()
+ c = getStrictDefUseChains(node)()
else:
- c = beniget.DefUseChains()
+ c = getDefUseChainsType(node)()
+
c.visit(node)
- self.assertEqual(c.dump_chains(node), ref)
+ self.assertEqual(c.dump_chains(node), ref)
return node, c
+
+ def checkUseDefChains(self, code, ref, strict=True):
+ node = self.ast.parse(code)
+ if strict:
+ c = getStrictDefUseChains(node)()
+ else:
+ c = getDefUseChainsType(node)()
+
+ c.visit(node)
+ cc = getUseDefChainsType(node)(c)
+
+ self.assertEqual(str(cc), ref)
def test_simple_expression(self):
code = "a = 1; a + 2"
@@ -398,8 +439,8 @@ def test_class_base(self):
def test_def_used_in_self_default(self):
code = "def foo(x:foo): return foo"
- c = beniget.DefUseChains()
- node = ast.parse(code)
+ node = self.ast.parse(code)
+ c = getDefUseChainsType(node)()
c.visit(node)
self.assertEqual(c.dump_chains(node), ["foo -> (foo -> ())"])
@@ -411,36 +452,36 @@ class mytype(str):
x = x+1 # <- this triggers NameError: name 'x' is not defined
return x
'''
- c = beniget.DefUseChains()
- node = ast.parse(code)
+ node = self.ast.parse(code)
+ c = getDefUseChainsType(node)()
c.visit(node)
self.assertEqual(c.dump_chains(node.body[0]), ['x -> (x -> ())', 'mytype -> ()'])
def test_unbound_class_variable2(self):
code = '''class A:\n a = 10\n def f(self):\n return a # a is not defined'''
- c = beniget.DefUseChains()
- node = ast.parse(code)
+ node = self.ast.parse(code)
+ c = getDefUseChainsType(node)()
c.visit(node)
self.assertEqual(c.dump_chains(node.body[0]), ['a -> ()', 'f -> ()'])
def test_unbound_class_variable3(self):
code = '''class A:\n a = 10\n class I:\n b = a + 1 # a is not defined'''
- c = beniget.DefUseChains()
- node = ast.parse(code)
+ node = self.ast.parse(code)
+ c = getDefUseChainsType(node)()
c.visit(node)
self.assertEqual(c.dump_chains(node.body[0]), ['a -> ()', 'I -> ()'])
def test_unbound_class_variable4(self):
code = '''class A:\n a = 10\n f = lambda: a # a is not defined'''
- c = beniget.DefUseChains()
- node = ast.parse(code)
+ node = self.ast.parse(code)
+ c = getDefUseChainsType(node)()
c.visit(node)
self.assertEqual(c.dump_chains(node.body[0]), ['a -> ()', 'f -> ()'])
def test_unbound_class_variable5(self):
code = '''class A:\n a = 10\n b = [a for _ in range(10)] # a is not defined'''
- c = beniget.DefUseChains()
- node = ast.parse(code)
+ node = self.ast.parse(code)
+ c = getDefUseChainsType(node)()
c.visit(node)
self.assertEqual(c.dump_chains(node.body[0]), ['a -> ()', 'b -> ()'])
@@ -463,14 +504,14 @@ def count(self) -> mytype: # this should trigger unbound identifier
def c(x) -> mytype(): # this one shouldn't
...
'''
- c = beniget.DefUseChains()
- node = ast.parse(code)
+ node = self.ast.parse(code)
+ c = getDefUseChainsType(node)()
c.visit(node)
self.assertEqual(c.dump_chains(node.body[0].body[0]), ['mytype -> (mytype -> (Call -> ()))'])
def check_message(self, code, expected_messages, filename=None):
- node = ast.parse(code)
- c = beniget.DefUseChains(filename)
+ node = self.ast.parse(code)
+ c = getDefUseChainsType(node)(filename)
with captured_output() as (out, err):
c.visit(node)
@@ -524,8 +565,8 @@ class Visitor:
def visit_Name(self, node):pass
visit_Attribute = visit_Name
'''
- node = ast.parse(code)
- c = beniget.DefUseChains()
+ node = self.ast.parse(code)
+ c = getDefUseChainsType(node)()
c.visit(node)
self.assertEqual(c.dump_chains(node), ['visit_Name -> ()',
'Visitor -> ()'])
@@ -540,8 +581,8 @@ class Attr(object):pass
class Visitor:
class Attr(Attr):pass
'''
- node = ast.parse(code)
- c = beniget.DefUseChains()
+ node = self.ast.parse(code)
+ c = getDefUseChainsType(node)()
c.visit(node)
self.assertEqual(c.dump_chains(node),
['Attr -> (Attr -> (Attr -> ()))',
@@ -593,8 +634,8 @@ class Visitor:
def f(): return f()
def visit_Name(self, node:Thing, fn:f):...
'''
- node = ast.parse(code)
- c = beniget.DefUseChains()
+ node = self.ast.parse(code)
+ c = getDefUseChainsType(node)()
c.visit(node)
self.assertEqual(c.dump_chains(node),
['Thing -> ()',
@@ -611,8 +652,8 @@ def visit_Attribute(self, node):pass
class Visitor:
visit_Attribute = visit_Attribute
'''
- node = ast.parse(code)
- c = beniget.DefUseChains()
+ node = self.ast.parse(code)
+ c = getDefUseChainsType(node)()
c.visit(node)
self.assertEqual(c.dump_chains(node),
['visit_Attribute -> (visit_Attribute -> ())',
@@ -919,7 +960,7 @@ def Thing(self, y:Type[Thing]) -> Thing: # this is OK, and it links to the top l
"W: unbound identifier 'D'",
]
- assert len(produced_messages) == len(expected_warnings), len(produced_messages)
+ assert len(produced_messages) == len(expected_warnings), produced_messages
assert all(any(w in pw for pw in produced_messages) for w in expected_warnings)
# locals of C
@@ -1064,8 +1105,8 @@ class mytype2(int):
fn = outer()
'''
- mod = ast.parse(code)
- chains = beniget.DefUseChains('test')
+ mod = self.ast.parse(code)
+ chains = getDefUseChainsType(mod)('test')
with captured_output() as (out, err):
chains.visit(mod)
@@ -1110,8 +1151,8 @@ class mytype2(int):
# to the inner classes.
def test_lookup_scopes(self):
- from beniget.beniget import _get_lookup_scopes
- mod, fn, cls, lambd, gen, comp = ast.Module(), ast.FunctionDef(), ast.ClassDef(), ast.Lambda(), ast.GeneratorExp(), ast.DictComp()
+ from beniget.beniget import _get_lookup_scopes, def695
+ mod, fn, cls, lambd, gen, comp, typeparams = self.ast.Module(), self.ast.FunctionDef(), self.ast.ClassDef(), self.ast.Lambda(), self.ast.GeneratorExp(), self.ast.DictComp(), def695(body=[], d=self.ast.FunctionDef())
assert _get_lookup_scopes((mod, fn, fn, fn, cls)) == [mod, fn, fn, fn, cls]
assert _get_lookup_scopes((mod, fn, fn, fn, cls, fn)) == [mod, fn, fn, fn, fn]
assert _get_lookup_scopes((mod, cls, fn)) == [mod, fn]
@@ -1121,6 +1162,11 @@ def test_lookup_scopes(self):
assert _get_lookup_scopes((mod, fn)) == [mod, fn]
assert _get_lookup_scopes((mod, cls)) == [mod, cls]
assert _get_lookup_scopes((mod,)) == [mod]
+ assert _get_lookup_scopes((mod, typeparams)) == [mod, typeparams]
+ assert _get_lookup_scopes((mod, typeparams, typeparams)) == [mod, typeparams, typeparams]
+ assert _get_lookup_scopes((mod, cls, typeparams)) == [mod, cls, typeparams]
+ assert _get_lookup_scopes((mod, cls, cls, typeparams)) == [mod, cls, typeparams]
+ assert _get_lookup_scopes((mod, cls, cls, typeparams, fn)) == [mod, typeparams, fn]
with self.assertRaises(ValueError, msg='invalid heads: must include at least one element'):
_get_lookup_scopes(())
@@ -1237,28 +1283,500 @@ class A:
strict=False
)
- @skipIf(sys.version_info.major < 3, "Python 3 syntax")
def test_annotation_def_is_not_assign_target(self):
code = 'from typing import Optional; var:Optional'
self.checkChains(code, ['Optional -> (Optional -> ())',
'var -> ()'])
+
+ def test_pep563_disallowed_expressions(self):
+ cases = [
+ "def func(a: (yield)) -> ...: ...",
+ "def func(a: ...) -> (yield from []): ...",
+ "def func(*a: (y := 3)) -> ...: ...",
+ "def func(**a: (await 42)) -> ...: ...",
+
+ "x: (yield) = True",
+ "x: (yield from []) = True",
+ "x: (y := 3) = True",
+ "x: (await 42) = True",]
+
+ for code in cases:
+ code = f'from __future__ import annotations\n' + code
+ with self.subTest(code):
+ self.check_message(code, ['cannot be used in annotation-like scopes'])
-class TestUseDefChains(TestCase):
- def checkChains(self, code, ref):
- class StrictDefUseChains(beniget.DefUseChains):
- def unbound_identifier(self, name, node):
- raise RuntimeError(
- "W: unbound identifier '{}' at {}:{}".format(
- name, node.lineno, node.col_offset
- )
- )
+ for code in cases:
+ with self.subTest(code):
+ # From python 3.13, this should generate the same error.
+ self.check_message(code, [])
+
+ # PEP-695 test cases taken from https://github.com/python/cpython/pull/103764/files
+ # but also https://github.com/python/cpython/pull/109297/files and
+ # https://github.com/python/cpython/pull/109123/files
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_typeparams_name_collision_01(self):
+ # The following code triggers syntax error at runtime.
+ # But detecting this would required beniget to keep track of the
+ # names of type parameters and validate them like we validate comprehensions or annotations.
+ # We don't do it for functions currently, so it doesn't make sens to do it for
+ # type parameters at this time.
+ code = """def func[**A, A](): ..."""
+ self.checkChains(code, ['func -> ()'])
+ self.checkUseDefChains(code, 'func <- {A, A}')
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_typeparams_name_non_collision_02(self):
+ code = """def func[A](A): return A"""
+ self.checkChains(code, ['func -> ()'])
+ self.checkUseDefChains(code, 'A <- {A}, A <- {}, func <- {A}')
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_typeparams_name_non_collision_03(self):
+ code = """def func[A](*A): return A"""
+ self.checkChains(code, ['func -> ()'])
+ self.checkUseDefChains(code, 'A <- {A}, A <- {}, func <- {A}')
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_typeparams_name_non_collision_04(self):
+ # Mangled names should not cause a conflict.
+ code = """class ClassA:\n def func[__A](self, __A): return __A"""
+ self.checkChains(code, ['ClassA -> ()'])
+ self.checkUseDefChains(code, '__A <- {__A}, __A <- {}, func <- {__A}, self <- {}')
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_typeparams_name_non_collision_05(self):
+ code = """class ClassA:\n def func[_ClassA__A](self, __A): return __A"""
+ self.checkChains(code, ['ClassA -> ()'])
+ self.checkUseDefChains(code, '__A <- {__A}, __A <- {}, func <- {_ClassA__A}, self <- {}')
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_typeparams_name_non_collision_06(self):
+ code = """class ClassA[X]:\n def func(self, X): return X"""
+ self.checkChains(code, ['ClassA -> ()'])
+ self.checkUseDefChains(code, 'ClassA <- {X}, X <- {X}, X <- {}, self <- {}')
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_typeparams_name_non_collision_07(self):
+ code = """class ClassA[X]:\n def func(self):\n X = 1;return X"""
+ self.checkChains(code, ['ClassA -> ()'])
+ self.checkUseDefChains(code, 'ClassA <- {X}, X <- {X}, X <- {}, self <- {}')
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_typeparams_name_non_collision_08(self):
+ code = """class ClassA[X]:\n def func(self): return [X for X in [1, 2]]"""
+ self.checkChains(code, ['ClassA -> ()'])
+ self.checkUseDefChains(code, 'ClassA <- {X}, List <- {Constant, Constant}, ListComp <- {X, comprehension}, X <- {X}, X <- {}, comprehension <- {List}, self <- {}')
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_typeparams_name_non_collision_09(self):
+ code = """class ClassA[X]:\n def func[X](self):..."""
+ self.checkChains(code, ['ClassA -> ()'])
+ self.checkUseDefChains(code, 'ClassA <- {X}, func <- {X}, self <- {}')
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_typeparams_name_non_collision_10(self):
+ code = """class ClassA[X]:\n X: int"""
+ self.checkChains(code, ['ClassA -> ()'])
+ self.checkUseDefChains(code, 'ClassA <- {X}, X <- {}, int <- {type}')
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_typeparams_name_non_collision_13(self):
+ code = """X = 1\ndef outer():\n def inner[X]():\n global X;X=2\n return inner"""
+ node, chains = self.checkChains(code, ['X -> ()', 'outer -> ()'])
+ self.assertEqual(chains.dump_chains(node.body[-1]), ['inner -> (inner -> ())'])
+ self.checkUseDefChains(code, 'X <- {}, X <- {}, inner <- {X}, inner <- {inner}')
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_typeparams_disallowed_expressions(self):
+ cases = ["type X = (yield)",
+ "type X = (yield from x)",
+ "type X = (await 42)",
+ "async def f(): type X = (yield)",
+ "type X = (y := 3)",
+ "class X[T: (yield)]: pass",
+ "class X[T: (yield from [])]: pass",
+ "class X[T: (await 42)]: pass",
+ "class X[T: (y := 3)]: pass",
+ "class X[T](y := list[T]): pass",
+ "def f[T](y: (x := list[T])): pass",]
+
+ for code in cases:
+ with self.subTest(code):
+ self.check_message(code, ['cannot be used in annotation-like scopes'])
+
+ for code in cases:
+ code = f'from __future__ import annotations\n' + code
+ with self.subTest(code):
+ self.check_message(code, ['cannot be used in annotation-like scopes'])
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_type_alias_name_collision_01(self):
+ # syntax error at runtime "duplicate type parameter 'A'"
+ code = """type TA1[A, **A] = None"""
+ self.checkChains(code, ['TA1 -> ()'])
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_type_alias_name_non_collision_02(self):
+ code = """type TA1[A] = lambda A: A"""
+ self.checkChains(code, ['TA1 -> ()'])
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_type_alias_name_non_collision_03(self):
+ code = """class Outer[A]:\n type TA1[A] = None"""
+ self.checkChains(code, ['Outer -> ()'])
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_type_alias_access_01(self):
+ code = "type TA1[A, B] = dict[A, B]"
+ self.checkChains(code, ['TA1 -> ()'])
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_type_alias_access_02(self):
+ code = """type TA1[A, B] = TA1[A, B] | int"""
+ self.checkChains(code, ['TA1 -> (TA1 -> (Subscript -> (BinOp -> ())))'])
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_type_alias_access_03(self):
+ code = """class Outer[A]:\n def inner[B](self):\n type TA1[C] = TA1[A, B] | int; return TA1"""
+ self.checkChains(code, ['Outer -> ()'])
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_scopes01(self):
+ code = """\
+from typing import Sequence
+
+# The following generates no compiler error, but a type checker
+# should generate an error because an upper bound type must be concrete,
+# and ``Sequence[S]`` is generic. Future extensions to the type system may
+# eliminate this limitation.
+class ClassA[S, T: Sequence[S]]: ...
+
+# The following generates no compiler error, because the bound for ``S``
+# is lazily evaluated. However, type checkers should generate an error.
+class ClassB[S: Sequence[T], T]: ...
+"""
+ self.checkChains(code, ['Sequence -> (Sequence -> (Subscript -> ()), Sequence -> (Subscript -> ()))',
+ 'ClassA -> ()',
+ 'ClassB -> ()'])
+ self.checkUseDefChains(code, 'ClassA <- {S, T}, ClassB <- {S, T}, S <- {S}, '
+ 'Sequence <- {Sequence}, Sequence <- {Sequence}, '
+ 'Subscript <- {S, Sequence}, Subscript <- {Sequence, T}, T <- {T}')
- node = ast.parse(code)
- c = StrictDefUseChains()
- c.visit(node)
- cc = beniget.UseDefChains(c)
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_scopes02(self):
+ code = """\
+from x import BaseClass, dec, Foo
- self.assertEqual(str(cc), ref)
+class ClassA[T](BaseClass[T], param = Foo[T]): ... # OK
+
+print(T) # Runtime error: 'T' is not defined
+
+@dec(Foo[T]) # Runtime error: 'T' is not defined
+class ClassA[T]: ...
+"""
+ self.check_message(code, ["W: unbound identifier 'T' at :5:6",
+ "W: unbound identifier 'T' at :7:9"])
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_scopes03(self):
+ code = """\
+from x import dec
+def func1[T](a: T) -> T: ... # OK
+
+print(T) # Runtime error: 'T' is not defined
+
+def func2[T](a = list[T]): ... # Runtime error: 'T' is not defined
+
+@dec(list[T]) # Runtime error: 'T' is not defined
+def func3[T](): ...
+"""
+ self.check_message(code, ["W: unbound identifier 'T' at :4:6",
+ "W: unbound identifier 'T' at :6:22",
+ "W: unbound identifier 'T' at :8:10"])
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_scopes04(self):
+
+ code = """\
+S = 0
+
+def outer1[S]():
+ S = 1
+ T = 1
+
+ def outer2[T]():
+
+ def inner1():
+ nonlocal S # OK because it binds variable S from outer1
+ print(S)
+ nonlocal T # Syntax error: nonlocal binding not allowed for type parameter
+ print(T)
+
+ def inner2():
+ global S # OK because it binds variable S from global scope
+ print(S)
+"""
+ self.check_message(code, ['W: names defined in annotation scopes cannot be rebound with nonlocal statements at :12:12'])
+ self.checkChains(code, ['S -> (S -> (Call -> ()))', 'outer1 -> ()'], strict=False)
+ self.checkUseDefChains(code, 'Call <- {S, print}, Call <- {S, print}, Call <- {T, print}, S <- {S}, S <- {S}, S <- {}, S <- {}, T <- {T}, T <- {}, outer1 <- {S}, outer2 <- {T}, print <- {builtin_function_or_method}, print <- {builtin_function_or_method}, print <- {builtin_function_or_method}', strict=False)
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_scopes04bis(self):
+ code = '''\
+def outer1():
+ def outer2[T]():
+ def inner1():
+ print(T)
+ inner1()
+ outer2()
+outer1()
+'''
+ self.check_message(code, [])
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_scopes05(self):
+ code = """\
+from typing import Sequence
+
+class Outer:
+ class Private:
+ pass
+
+ # If the type parameter scope was like a traditional scope,
+ # the base class 'Private' would not be accessible here.
+ class Inner[T](Private, Sequence[T]):
+ pass
+
+ # Likewise, 'Inner' would not be available in these type annotations.
+ def method1[T](self, a: Inner[T]) -> Inner[T]:
+ return a
+"""
+ self.checkChains(code, ['Sequence -> (Sequence -> (Subscript -> (Inner -> (Inner -> (Subscript -> ()), Inner -> (Subscript -> ())))))',
+ 'Outer -> ()'])
+ self.checkUseDefChains(code, 'Inner <- {Inner}, Inner <- {Inner}, Inner <- {Private, Subscript, T}, Private <- {Private}, '
+ 'Sequence <- {Sequence}, Subscript <- {Inner, T}, Subscript <- {Inner, T}, Subscript <- {Sequence, T}, '
+ 'T <- {T}, T <- {T}, T <- {T}, a <- {a}, a <- {}, method1 <- {T}, self <- {}')
+
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_scopes06(self):
+ code = """\
+from typing import Sequence
+from x import decorator
+
+T = 0
+
+@decorator(T) # Argument expression `T` evaluates to 0
+class ClassA[T](Sequence[T]):
+ T = 1
+
+ # All methods below should result in a type checker error
+ # "type parameter 'T' already in use" because they are using the
+ # type parameter 'T', which is already in use by the outer scope
+ # 'ClassA'.
+ def method1[T](self):
+ ...
+
+ def method2[T](self, x = T): # Parameter 'x' gets default value of 1
+ ...
+
+ def method3[T](self, x: T): # Parameter 'x' has type T (scoped to method3)
+ ...
+
+"""
+ self.checkChains(code, ['Sequence -> (Sequence -> (Subscript -> (ClassA -> ())))',
+ 'decorator -> (decorator -> (Call -> (ClassA -> ())))',
+ 'T -> (T -> (Call -> (ClassA -> ())))',
+ 'ClassA -> ()'])
+ self.checkUseDefChains(code, 'Call <- {T, decorator}, ClassA <- {Call, Subscript, T}, '
+ 'Sequence <- {Sequence}, Subscript <- {Sequence, T}, T <- {T}, T <- {T}, '
+ 'T <- {T}, T <- {T}, T <- {}, T <- {}, decorator <- {decorator}, method1 <- {T}, '
+ 'method2 <- {T, T}, method3 <- {T}, self <- {}, self <- {}, self <- {}, x <- {}, x <- {}')
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_scopes07(self):
+ code = """\
+T = 0
+
+# T refers to the global variable
+print(T) # Prints 0
+
+class Outer[T]:
+ T = 1
+
+ # T refers to the local variable scoped to class 'Outer'
+ print(T) # Prints 1
+
+ class Inner1:
+ T = 2
+
+ # T refers to the local type variable within 'Inner1'
+ print(T) # Prints 2
+
+ def inner_method(self):
+ # T refers to the type parameter scoped to class 'Outer';
+ # If 'Outer' did not use the new type parameter syntax,
+ # this would instead refer to the global variable 'T'
+ print(T) # Prints 'T'
+
+ def outer_method(self):
+ T = 3
+
+ # T refers to the local variable within 'outer_method'
+ print(T) # Prints 3
+
+ def inner_func():
+ # T refers to the variable captured from 'outer_method'
+ print(T) # Prints 3
+"""
+ self.checkChains(code, ['T -> (T -> (Call -> ()))', 'Outer -> ()'])
+ self.checkUseDefChains(code, 'Call <- {T, print}, Call <- {T, print}, Call <- {T, print}, Call <- {T, print}, '
+ 'Call <- {T, print}, Call <- {T, print}, Outer <- {T}, T <- {T}, T <- {T}, T <- {T}, '
+ 'T <- {T}, T <- {T}, T <- {T}, T <- {}, T <- {}, T <- {}, T <- {}, print <- {builtin_function_or_method}, '
+ 'print <- {builtin_function_or_method}, print <- {builtin_function_or_method}, '
+ 'print <- {builtin_function_or_method}, print <- {builtin_function_or_method}, '
+ 'print <- {builtin_function_or_method}, self <- {}, self <- {}')
+
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_scopes08(self):
+ code = '''\
+from x import decorator
+T = 1
+@decorator
+def f[decorator, T: int, U: (int, str), *Ts, **P](
+ y: U,
+ x: T = T, # default values are evaluated outside the def695 scope
+ *args: *Ts,
+ **kwargs: P.kwargs,
+) -> T:
+ return x
+'''
+ self.checkChains(code, ['decorator -> (decorator -> ())',
+ 'T -> (T -> (f -> ()))',
+ 'f -> ()'])
+ self.checkUseDefChains(code, 'Attribute <- {P}, P <- {P}, Starred <- {Ts}, T <- {T}, T <- {T}, T <- {T}, T <- {}, '
+ 'Ts <- {Ts}, Tuple <- {int, str}, U <- {U}, args <- {}, decorator <- {decorator}, '
+ 'f <- {P, T, T, Ts, U, decorator}, int <- {type}, int <- {type}, kwargs <- {}, '
+ 'str <- {type}, x <- {x}, x <- {}, y <- {}')
+
+ @skipIf(sys.version_info < (3,12), "Python 3.12 syntax")
+ def test_pep695_scopes09(self):
+ code = '''\
+from x import decorator
+@decorator
+class B[decorator](object):
+ print(decorator)
+'''
+ self.checkChains(code, ['decorator -> (decorator -> (B -> ()))',
+ 'B -> ()'])
+
+ @skipIf(sys.version_info < (3,10), "Python 3.10 syntax")
+ def test_match_value(self):
+ code = '''
+command = 123
+match command:
+ case 123 as b:
+ b+=1
+ '''
+ self.checkChains(code, ['command -> (command -> ())',
+ 'b -> (b -> ())'])
+
+ @skipIf(sys.version_info < (3,10), "Python 3.10 syntax")
+ def test_match_list(self):
+ code = '''
+command = 'go there'
+match command.split():
+ case ["go", direction]:
+ print(direction)
+ case _:
+ raise ValueError("Sorry")
+ '''
+ self.checkChains(code, ['command -> (command -> (Attribute -> (Call -> ())))',
+ 'direction -> (MatchSequence -> (), direction -> (Call -> ()))'])
+
+ @skipIf(sys.version_info < (3,10), "Python 3.10 syntax")
+ def test_match_list_star(self):
+ code = '''
+command = 'drop'
+match command.split():
+ case ["go", direction]: ...
+ case ["drop", *objects]:
+ print(objects)
+ '''
+ self.checkChains(code, ['command -> (command -> (Attribute -> (Call -> ())))',
+ 'direction -> (MatchSequence -> ())',
+ 'objects -> (MatchSequence -> (), objects -> (Call -> ()))'])
+
+ @skipIf(sys.version_info < (3,10), "Python 3.10 syntax")
+ def test_match_dict(self):
+ code = '''
+ui = object()
+action = dict(text='')
+match action:
+ case {"text": str(message), "color": str(c), **rest}:
+ ui.set_text_color(c)
+ ui.display(message)
+ print(rest)
+ case {"sleep": float(duration)}:
+ ui.wait(duration)
+ case {"sound": str(url), "format": "ogg"}:
+ ui.play(url)
+ case {"sound": _, "format": _}:
+ raise ValueError("Unsupported audio format")
+print(c)
+ '''
+ self.checkChains(code, ['ui -> (ui -> (Attribute -> (Call -> ())), ui -> (Attribute -> (Call -> ())), ui -> (Attribute -> (Call -> ())), ui -> (Attribute -> (Call -> ())))',
+ 'action -> (action -> ())',
+ 'message -> (MatchClass -> (rest -> (rest -> (Call -> ()))), message -> (Call -> ()))',
+ 'c -> (MatchClass -> (rest -> (rest -> (Call -> ()))), c -> (Call -> ()), c -> (Call -> ()))',
+ 'rest -> (rest -> (Call -> ()))',
+ 'duration -> (MatchClass -> (MatchMapping -> ()), duration -> (Call -> ()))',
+ 'url -> (MatchClass -> (MatchMapping -> ()), url -> (Call -> ()))'])
+
+ @skipIf(sys.version_info < (3,10), "Python 3.10 syntax")
+ def test_match_class_rebinds_attrs(self):
+
+ code = '''
+from dataclasses import dataclass
+
+@dataclass
+class Point:
+ x: int
+ y: int
+
+point = Point(-2,1)
+match point:
+ case Point(x=0, y=0):
+ print("Origin")
+ case Point(x=0, y=y):
+ print(f"Y={y}")
+ case Point(x=x, y=0):
+ print(f"X={x}")
+ case Point(x=x, y=y):
+ print("Somewhere else")
+ case _:
+ print("Not a point")
+print(x, y)
+ '''
+ self.checkChains(
+ code, ['dataclass -> (dataclass -> (Point -> (Point -> (Call -> ()), Point -> (MatchClass -> ()), Point -> (MatchClass -> ()), Point -> (MatchClass -> ()), Point -> (MatchClass -> ()))))',
+ 'Point -> (Point -> (Call -> ()), Point -> (MatchClass -> ()), Point -> (MatchClass -> ()), Point -> (MatchClass -> ()), Point -> (MatchClass -> ()))',
+ 'point -> (point -> ())',
+ 'y -> (MatchClass -> (), y -> (FormattedValue -> (JoinedStr -> (Call -> ()))), y -> (Call -> ()))',
+ 'x -> (MatchClass -> (), x -> (FormattedValue -> (JoinedStr -> (Call -> ()))), x -> (Call -> ()))',
+ 'x -> (MatchClass -> (), x -> (Call -> ()))',
+ 'y -> (MatchClass -> (), y -> (Call -> ()))'])
+
+
+class TestDefUseChainsStdlib(TestDefUseChains):
+ ast = _ast
+
+
+class TestUseDefChains(TestCase):
+ ast = _gast
+ checkChains = TestDefUseChains.checkUseDefChains
def test_simple_expression(self):
code = "a = 1; a"
@@ -1267,3 +1785,7 @@ def test_simple_expression(self):
def test_call(self):
code = "from foo import bar; bar(1, 2)"
self.checkChains(code, "Call <- {Constant, Constant, bar}, bar <- {bar}")
+
+class TestUseDefChainsStdlib(TestDefUseChains):
+ ast = _ast
+
diff --git a/tests/test_definitions.py b/tests/test_definitions.py
index e9f956a..c5ee0c9 100644
--- a/tests/test_definitions.py
+++ b/tests/test_definitions.py
@@ -1,23 +1,16 @@
from textwrap import dedent
from unittest import TestCase
-import gast as ast
-import beniget
+import gast as _gast
+import ast as _ast
import sys
-
-class StrictDefUseChains(beniget.DefUseChains):
- def unbound_identifier(self, name, node):
- raise RuntimeError(
- "W: unbound identifier '{}' at {}:{}".format(
- name, node.lineno, node.col_offset
- )
- )
-
+from .test_chains import getStrictDefUseChains
class TestGlobals(TestCase):
+ ast = _gast
def checkGlobals(self, code, ref):
- node = ast.parse(code)
- c = StrictDefUseChains()
+ node = self.ast.parse(code)
+ c = getStrictDefUseChains(node)()
c.visit(node)
self.assertEqual(c.dump_definitions(node), ref)
@@ -278,13 +271,16 @@ def testGlobalLambda(self):
code = "lambda x: x"
self.checkGlobals(code, [])
+class TestGlobalsStdlib(TestGlobals):
+ ast = _ast
class TestClasses(TestCase):
+ ast = _gast
def checkClasses(self, code, ref):
- node = ast.parse(code)
- c = StrictDefUseChains()
+ node = self.ast.parse(code)
+ c = getStrictDefUseChains(node)()
c.visit(node)
- classes = [n for n in node.body if isinstance(n, ast.ClassDef)]
+ classes = [n for n in node.body if isinstance(n, self.ast.ClassDef)]
assert len(classes) == 1, "only one top-level function per test case"
cls = classes[0]
self.assertEqual(c.dump_definitions(cls), ref)
@@ -293,13 +289,16 @@ def test_class_method_assign(self):
code = "class C:\n def foo(self):pass\n bar = foo"
self.checkClasses(code, ["bar", "foo"])
+class TestClassesStdlib(TestClasses):
+ ast = _ast
class TestLocals(TestCase):
+ ast = _gast
def checkLocals(self, code, ref):
- node = ast.parse(dedent(code))
- c = StrictDefUseChains()
+ node = self.ast.parse(dedent(code))
+ c = getStrictDefUseChains(node)()
c.visit(node)
- functions = [n for n in node.body if isinstance(n, ast.FunctionDef)]
+ functions = [n for n in node.body if isinstance(n, self.ast.FunctionDef)]
assert len(functions) == 1, "only one top-level function per test case"
f = functions[0]
self.assertEqual(c.dump_definitions(f), ref)
@@ -394,15 +393,20 @@ def foo(a):
else: b = a"""
self.checkLocals(code, ["a", "b"])
+class TestLocalsStdlib(TestLocals):
+ ast = _ast
+
class TestDefIsLive(TestCase):
+ ast = _gast
+
def checkLocals(self, c, node, ref, only_live=False):
self.assertEqual(sorted(c._dump_locals(node, only_live=only_live)),
sorted(ref))
def checkLiveLocals(self, code, livelocals, locals):
- node = ast.parse(dedent(code))
- c = StrictDefUseChains()
+ node = self.ast.parse(dedent(code))
+ c = getStrictDefUseChains(node)()
c.visit(node)
self.checkLocals(c, node, locals)
self.checkLocals(c, node, livelocals, only_live=True)
@@ -552,3 +556,5 @@ def test_more_loops(self):
self.checkLiveLocals(code, ['b:2,6', 'v:9,10,4', 'k:10,13'],
['b:2,6', 'v:9,10,4', 'k:10,13'])
+class TestDefIsLiveStdlib(TestDefIsLive):
+ ast = _ast
\ No newline at end of file
diff --git a/tox.ini b/tox.ini
index 3e9111b..6ae1ab3 100644
--- a/tox.ini
+++ b/tox.ini
@@ -2,5 +2,6 @@
envlist = py27,py36,py37,py38,py39,py310,py311,py312
[testenv]
deps =
+ git+https://github.com/serge-sans-paille/gast.git
pytest
-commands=pytest beniget/ tests/ --doctest-modules
\ No newline at end of file
+commands=pytest beniget/ tests/ --doctest-modules