diff --git a/.gitignore b/.gitignore index 4ca31db..ee12072 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ beniget.egg-info .pytest_cache .vscode +build +.tox \ No newline at end of file diff --git a/README.rst b/README.rst index 85746ca..426b431 100644 --- a/README.rst +++ b/README.rst @@ -4,8 +4,8 @@ Gast, Beniget! Beniget is a collection of Compile-time analyse on Python Abstract Syntax Tree(AST). It's a building block to write static analyzer or compiler for Python. -Beniget relies on `gast `_ to provide a cross -version abstraction of the AST, effectively working across all Python 3 versions greater than 3.6. +Beniget is compatible with the standard library AST as well as the AST generated by `gast `_, +which is a cross version abstraction of the AST. API --- @@ -16,6 +16,12 @@ Basically Beniget provides three analyse: - ``beniget.DefUseChains`` that maps each node to the list of definition points in that node; - ``beniget.UseDefChains`` that maps each node to the list of possible definition of that node. +Alternatives working with standard library AST: + +- ``beniget.standard.Ancestors`` that maps each node to the list of enclosing nodes; +- ``beniget.standard.DefUseChains`` that maps each node to the list of definition points in that node; +- ``beniget.standard.UseDefChains`` that maps each node to the list of possible definition of that node. + See sample usages and/or run ``pydoc beniget`` for more information :-). diff --git a/beniget/beniget.py b/beniget/beniget.py index 9f45161..29b1edb 100644 --- a/beniget/beniget.py +++ b/beniget/beniget.py @@ -1,30 +1,52 @@ from collections import defaultdict from contextlib import contextmanager +import itertools import sys +import ast as _ast import gast as ast from .ordered_set import ordered_set +_ClassOrFunction = set(('ClassDef', 'FunctionDef', 'AsyncFunctionDef')) +_Comp = set(('DictComp', 'ListComp', 'SetComp', 'GeneratorExp')) +_ClosedScopes = set(('FunctionDef', 'AsyncFunctionDef', + 'Lambda', 'DictComp', 'ListComp', + 'SetComp', 'GeneratorExp', 'def695')) +_TypeVarLike = set(('TypeVar', 'TypeVarTuple', 'ParamSpec')) +_HasName = set((*_ClassOrFunction, *_TypeVarLike)) + class Ancestors(ast.NodeVisitor): """ Build the ancestor tree, that associates a node to the list of node visited from the root node (the Module) to the current node - >>> import gast as ast >>> code = 'def foo(x): return x + 1' >>> module = ast.parse(code) - >>> from beniget import Ancestors >>> ancestors = Ancestors() >>> ancestors.visit(module) - >>> binop = module.body[0].body[0].value >>> for n in ancestors.parents(binop): ... print(type(n)) + + Also works with standard library nodes + + >>> import ast as _ast + >>> code = 'def foo(x): return x + 1' + >>> module = _ast.parse(code) + >>> from beniget import Ancestors + >>> ancestors = Ancestors() + >>> ancestors.visit(module) + >>> binop = module.body[0].body[0].value + >>> for n in ancestors.parents(binop): + ... print(str(type(n)).replace('_ast', 'ast')) + + + """ def __init__(self): @@ -34,7 +56,7 @@ def __init__(self): def generic_visit(self, node): self._parents[node] = list(self._current) self._current.append(node) - super(Ancestors, self).generic_visit(node) + super().generic_visit(node) self._current.pop() def parent(self, node): @@ -51,11 +73,29 @@ def parentInstance(self, node, cls): def parentFunction(self, node): return self.parentInstance(node, (ast.FunctionDef, - ast.AsyncFunctionDef)) + ast.AsyncFunctionDef, + _ast.FunctionDef, + _ast.AsyncFunctionDef)) def parentStmt(self, node): - return self.parentInstance(node, ast.stmt) + return self.parentInstance(node, _ast.stmt) +_novalue = object() +@contextmanager +def _setattrs(obj, **attrs): + """ + Provide cheap attribute polymorphism. + """ + old_values = {} + for k, v in attrs.items(): + old_values[k] = getattr(obj, k, _novalue) + setattr(obj, k, v) + yield + for k, v in old_values.items(): + if v is _novalue: + delattr(obj, k) + else: + setattr(obj, k, v) class Def(object): """ @@ -76,7 +116,7 @@ def __init__(self, node): """ def add_user(self, node): - assert isinstance(node, Def) + assert isinstance(node, Def), node self._users.add(node) def name(self): @@ -84,19 +124,25 @@ def name(self): If the node associated to this Def has a name, returns this name. Otherwise returns its type """ - if isinstance(self.node, (ast.ClassDef, - ast.FunctionDef, - ast.AsyncFunctionDef)): + typename = type(self.node).__name__ + if typename in _HasName: return self.node.name - elif isinstance(self.node, ast.Name): + elif typename == 'Name': return self.node.id - elif isinstance(self.node, ast.alias): + elif typename == 'alias': base = self.node.name.split(".", 1)[0] return self.node.asname or base + elif typename in ('MatchStar', 'MatchAs') and self.node.name: + return self.node.name + elif typename == 'MatchMapping' and self.node.rest: + return self.node.rest + elif typename == 'arg': + return self.node.arg + elif typename == 'ExceptHandler' and self.node.name: + return self.node.name elif isinstance(self.node, tuple): return self.node[1] - else: - return type(self.node).__name__ + return typename def users(self): """ @@ -144,7 +190,7 @@ def collect_future_imports(node): """ Returns a set of future imports names for the given ast module. """ - assert isinstance(node, ast.Module) + assert type(node).__name__ == 'Module' cf = _CollectFutureImports() cf.visit(node) return cf.FutureImports @@ -185,6 +231,9 @@ def visit_Constant(self, node): def generic_visit(self, node): raise _StopTraversal() + def visit_Str(self, node): + pass + class CollectLocals(ast.NodeVisitor): def __init__(self): self.Locals = set() @@ -204,7 +253,7 @@ def visit_Nonlocal(self, node): visit_Global = visit_Nonlocal def visit_Name(self, node): - if isinstance(node.ctx, ast.Store) and node.id not in self.NonLocals: + if type(node.ctx).__name__ == 'Store' and node.id not in self.NonLocals: self.Locals.add(node.id) def skip(self, _): @@ -224,16 +273,32 @@ def visit_ImportFrom(self, node): for alias in node.names: self.Locals.add(alias.asname or alias.name) +class CollectLocalsdef695(CollectLocals): + + visit_TypeVar = visit_ParamSpec = visit_TypeVarTuple = CollectLocals.visit_FunctionDef + def collect_locals(node): ''' Compute the set of identifiers local to a given node. This is meant to emulate a call to locals() ''' - visitor = CollectLocals() + if isinstance(node, def695): + # workaround for the new implicit scope created by type params and co. + visitor = CollectLocalsdef695() + else: + visitor = CollectLocals() visitor.generic_visit(node) return visitor.Locals +class def695(ast.stmt): + """ + Special statement to represent the PEP-695 lexical scopes. + """ + _fields = ('body', 'd') + def __init__(self, body, d): + self.body = body # list of type params + self.d = d # the wrapped definition node class DefUseChains(ast.NodeVisitor): """ @@ -289,7 +354,9 @@ def __init__(self, filename=None): # be defined in another path of the control flow (esp. in loop) self._undefs = [] - # stack of nodes starting a scope: class, module, function, generator expression, comprehension... + # stack of nodes starting a scope: + # class, module, function, generator expression, + # comprehension, def695. self._scopes = [] self._breaks = [] @@ -321,11 +388,11 @@ def _dump_locals(self, node, only_live=False): for d in self.locals[node]: if not only_live or d.islive: groupped[d.name()].append(d) - return ['{}:{}'.format(name, ','.join([str(getattr(d.node, 'lineno', -1)) for d in defs])) \ + return ['{}:{}'.format(name, ','.join([str(getattr(d.node, 'lineno', None)) for d in defs])) \ for name,defs in groupped.items()] def dump_definitions(self, node, ignore_builtins=True): - if isinstance(node, ast.Module) and not ignore_builtins: + if type(node).__name__ == 'Module' and not ignore_builtins: builtins = {d for d in self._builtins.values()} return sorted(d.name() for d in self.locals[node] if d not in builtins) @@ -345,7 +412,7 @@ def location(self, node): ) return " at {}{}:{}".format(filename, node.lineno, - node.col_offset) + getattr(node, 'col_offset', None),) else: return "" @@ -375,7 +442,7 @@ def invalid_name_lookup(self, name, scope, precomputed_locals, local_defs): # >>> foo() # fails, a is a local referenced before being assigned # >>> class bar: a = a # >>> bar() # ok, and `bar.a is a` - if isinstance(scope, ast.ClassDef): + if type(scope).__name__ in ('ClassDef', 'def695'): # TODO: test the def695 part of this top_level_definitions = self._definitions[0:-self._scope_depths[0]] isglobal = any((name in top_lvl_def or '*' in top_lvl_def) for top_lvl_def in top_level_definitions) @@ -421,6 +488,7 @@ def compute_defs(self, node, quiet=False): precomputed_locals = next(precomputed_locals_iter) base_scope = next(scopes_iter) defs = self._definitions[depth:] + is_def695 = isinstance(base_scope, def695) if not self.invalid_name_lookup(name, base_scope, precomputed_locals, defs): looked_up_definitions.extend(reversed(defs)) @@ -428,7 +496,9 @@ def compute_defs(self, node, quiet=False): for scope, depth, precomputed_locals in zip(scopes_iter, depths_iter, precomputed_locals_iter): - if not isinstance(scope, ast.ClassDef): + # If a def695 scope is immediately within a class scope, or within another def695 scope that is immediately within a class scope, + # then names defined in that class scope can be accessed within the def695 scope. + if type(scope).__name__ != 'ClassDef' or is_def695: defs = self._definitions[lvl + depth: lvl] if self.invalid_name_lookup(name, base_scope, precomputed_locals, defs): looked_up_definitions.append(StopIteration) @@ -462,7 +532,7 @@ def process_body(self, stmts): deadcode = False for stmt in stmts: self.visit(stmt) - if isinstance(stmt, (ast.Break, ast.Continue, ast.Raise)): + if type(stmt).__name__ in ('Break', 'Continue', 'Raise'): if not deadcode: deadcode = True self._deadcode += 1 @@ -595,7 +665,7 @@ def set_definition(self, name, dnode_or_dnodes, index=-1): # set the islive flag to False on killed Defs for d in self._definitions[index].get(name, ()): - if not isinstance(d.node, ast.AST): + if not isinstance(d.node, _ast.AST): # A builtin: we never explicitely mark the builtins as killed, since # it can be easily deducted. continue @@ -650,48 +720,91 @@ def visit_annotation(self, node): self.visit(annotation) def visit_skip_annotation(self, node): - if isinstance(node, ast.Name): + if type(node).__name__ == 'Name': self.visit_Name(node, skip_annotation=True) else: self.visit(node) - def visit_FunctionDef(self, node, step=DeclarationStep): + def visit_FunctionDef(self, node, step=DeclarationStep, in_def695=False): if step is DeclarationStep: dnode = self.chains.setdefault(node, Def(node)) self.add_to_locals(node.name, dnode) + if not in_def695: + + for kw_default in filter(None, node.args.kw_defaults): + self.visit(kw_default).add_user(dnode) + for default in node.args.defaults: + self.visit(default).add_user(dnode) + for decorator in node.decorator_list: + self.visit(decorator) + + if any(getattr(node, 'type_params', [])): + self.visit_def695(def695(body=node.type_params, d=node)) + return + if not self.future_annotations: for arg in _iter_arguments(node.args): - self.visit_annotation(arg) + annotation = getattr(arg, 'annotation', None) + if annotation: + if in_def695: + try: + _validate_annotation_body(annotation) + except SyntaxError as e : + self.warn(str(e), annotation) + continue + self.visit(annotation) else: # annotations are to be analyzed later as well currentscopes = list(self._scopes) if node.returns: - self._defered_annotations[-1].append( - (node.returns, currentscopes, None)) + try: + _validate_annotation_body(node.returns) + except SyntaxError as e : + self.warn(str(e), node.returns) + else: + self._defered_annotations[-1].append( + (node.returns, currentscopes, None)) for arg in _iter_arguments(node.args): if arg.annotation: + try: + _validate_annotation_body(arg.annotation) + except SyntaxError as e : + self.warn(str(e), arg.annotation) + continue self._defered_annotations[-1].append( (arg.annotation, currentscopes, None)) - for kw_default in filter(None, node.args.kw_defaults): - self.visit(kw_default).add_user(dnode) - for default in node.args.defaults: - self.visit(default).add_user(dnode) - for decorator in node.decorator_list: - self.visit(decorator) - if not self.future_annotations and node.returns: - self.visit(node.returns) - - self.set_definition(node.name, dnode) + if in_def695: + try: + _validate_annotation_body(node.returns) + except SyntaxError as e: + self.warn(str(e), node.returns) + else: + self.visit(node.returns) + else: + self.visit(node.returns) + + if in_def695: + # emulate this (except f is not actually defined in both scopes): + # def695 __generic_parameters_of_f(): + # T = TypeVar(name='T') + # def f(x: T) -> T: + # return x + # return f + # f = __generic_parameters_of_f() + self.set_definition(node.name, dnode, index=-2) + else: + self.set_definition(node.name, dnode) self._defered.append((node, list(self._definitions), list(self._scopes), list(self._scope_depths), list(self._precomputed_locals))) + elif step is DefinitionStep: with self.ScopeContext(node): for arg in _iter_arguments(node.args): @@ -702,23 +815,44 @@ def visit_FunctionDef(self, node, step=DeclarationStep): visit_AsyncFunctionDef = visit_FunctionDef - def visit_ClassDef(self, node): + def visit_ClassDef(self, node, in_def695=False): dnode = self.chains.setdefault(node, Def(node)) self.add_to_locals(node.name, dnode) + + if not in_def695: + for decorator in node.decorator_list: + self.visit(decorator).add_user(dnode) + + if any(getattr(node, 'type_params', [])): + self.visit_def695(def695(body=node.type_params, d=node)) + return for base in node.bases: + if in_def695: + try: + _validate_annotation_body(base) + except SyntaxError as e: + self.warn(str(e), base) + continue self.visit(base).add_user(dnode) for keyword in node.keywords: + if in_def695: + try: + _validate_annotation_body(keyword) + except SyntaxError as e: + self.warn(str(e), keyword) + continue self.visit(keyword.value).add_user(dnode) - for decorator in node.decorator_list: - self.visit(decorator).add_user(dnode) with self.ScopeContext(node): self.set_definition("__class__", Def("__class__")) self.process_body(node.body) - self.set_definition(node.name, dnode) - + if in_def695: + # see comment in visit_FunctionDef + self.set_definition(node.name, dnode, index=-2) + else: + self.set_definition(node.name, dnode) def visit_Return(self, node): if node.value: @@ -750,13 +884,18 @@ def visit_AnnAssign(self, node): if not self.future_annotations: self.visit(node.annotation) else: - self._defered_annotations[-1].append( - (node.annotation, list(self._scopes), None)) + try: + _validate_annotation_body(node.annotation) + except SyntaxError as e: + self.warn(str(e), node.annotation) + else: + self._defered_annotations[-1].append( + (node.annotation, list(self._scopes), None)) self.visit(node.target) def visit_AugAssign(self, node): dvalue = self.visit(node.value) - if isinstance(node.target, ast.Name): + if type(node.target).__name__ == 'Name': ctx, node.target.ctx = node.target.ctx, ast.Load() dtarget = self.visit(node.target) dvalue.add_user(dtarget) @@ -773,6 +912,47 @@ def visit_AugAssign(self, node): self.locals[self._scopes[-1]].append(dtarget) else: self.visit(node.target).add_user(dvalue) + + def visit_TypeAlias(self, node, in_def695=False): + # Generic type aliases: + # type Alias[T: int] = list[T] + + # Equivalent to: + # def695 __generic_parameters_of_Alias(): + # def695 __evaluate_T_bound(): + # return int + # T = __make_typevar_with_bound(name='T', evaluate_bound=__evaluate_T_bound) + # def695 __evaluate_Alias(): + # return list[T] + # return __make_typealias(name='Alias', type_params=(T,), evaluate_value=__evaluate_Alias) + # Alias = __generic_parameters_of_Alias() + + if type(node.name).__name__ == 'Name': + dname = self.chains.setdefault(node.name, Def(node.name)) + self.add_to_locals(node.name.id, dname) + + if not in_def695 and any(getattr(node, 'type_params', [])): + self.visit_def695(def695(body=node.type_params, d=node)) + return + + dnode = self.chains.setdefault(node, Def(node)) + try: + _validate_annotation_body(node.value) + except SyntaxError as e: + self.warn(str(e), node.value) + else: + self._defered_annotations[-1].append( + (node.value, list(self._scopes), None)) + + if in_def695: + # see comment in visit_FunctionDef + self.set_definition(node.name.id, dname, index=-2) + else: + self.set_definition(node.name.id, dname) + + return dnode + else: + raise NotImplementedError() def visit_For(self, node): self.visit(node.iter) @@ -912,17 +1092,17 @@ def visit_Try(self, node): self.extend_definition(hd, handler_def[hd]) self.process_body(node.finalbody) - + def visit_Assert(self, node): self.visit(node.test) if node.msg: self.visit(node.msg) - def add_to_locals(self, name, dnode): + def add_to_locals(self, name, dnode, index=-1): if any(name in _globals for _globals in self._globals): self.set_or_extend_global(name, dnode) - else: - self.locals[self._scopes[-1]].append(dnode) + elif dnode not in self.locals[self._scopes[index]]: + self.locals[self._scopes[index]].append(dnode) def visit_Import(self, node): for alias in node.names: @@ -946,10 +1126,16 @@ def visit_Global(self, node): def visit_Nonlocal(self, node): for name in node.names: - for d in reversed(self._definitions[:-1]): + for i, d in enumerate(reversed(self._definitions)): + if i == 0: + continue if name not in d: continue else: + if isinstance(self._scopes[-i-1], def695): + # see https://docs.python.org/3.12/reference/executionmodel.html#annotation-scopes + self.warn("names defined in annotation scopes cannot be rebound with nonlocal statements", node) + break # this rightfully creates aliasing self.set_definition(name, d[name]) break @@ -959,7 +1145,100 @@ def visit_Nonlocal(self, node): def visit_Expr(self, node): self.generic_visit(node) - # expr + # pattern matching + + def visit_Match(self, node): + + self.visit(node.subject) + + defs = [] + for kase in node.cases: + if kase.guard: + self.visit(kase.guard) + self.visit(kase.pattern) + + with self.DefinitionContext(self._definitions[-1].copy()) as case_defs: + self.process_body(kase.body) + defs.append(case_defs) + + if not defs: + return + if len(defs) == 1: + body_defs, orelse_defs, rest = defs[0], [], [] + else: + body_defs, orelse_defs, rest = defs[0], defs[1], defs[2:] + while True: + # merge defs, like in if-else but repeat the process for x branches + for d in body_defs: + if d in orelse_defs: + self.set_definition(d, body_defs[d] + orelse_defs[d]) + else: + self.extend_definition(d, body_defs[d]) + for d in orelse_defs: + if d not in body_defs: + self.extend_definition(d, orelse_defs[d]) + if not rest: + break + body_defs = self._definitions[-1] + orelse_defs, rest = rest[0], rest[1:] + + def visit_MatchValue(self, node): + dnode = self.chains.setdefault(node, Def(node)) + self.visit(node.value) + return dnode + + visit_MatchSingleton = visit_MatchValue + + def visit_MatchSequence(self, node): + # mimics a list + with _setattrs(node, ctx=ast.Load(), elts=node.patterns): + return self.visit_List(node) + + def visit_MatchMapping(self, node): + dnode = self.chains.setdefault(node, Def(node)) + with _setattrs(node, values=node.patterns): + # mimics a dict + self.visit_Dict(node) + if node.rest: + with _setattrs(node, id=node.rest, ctx=ast.Store(), annotation=None): + self.visit_Name(node) + return dnode + + def visit_MatchClass(self, node): + # mimics a call + dnode = self.chains.setdefault(node, Def(node)) + self.visit(node.cls).add_user(dnode) + for arg in node.patterns: + self.visit(arg).add_user(dnode) + for kw in node.kwd_patterns: + self.visit(kw).add_user(dnode) + return dnode + + def visit_MatchStar(self, node): + dnode = self.chains.setdefault(node, Def(node)) + if node.name: + # mimics store name + with _setattrs(node, id=node.name, ctx=ast.Store(), annotation=None): + self.visit_Name(node) + return dnode + + def visit_MatchAs(self, node): + dnode = self.chains.setdefault(node, Def(node)) + if node.pattern: + self.visit(node.pattern) + if node.name: + with _setattrs(node, id=node.name, ctx=ast.Store(), annotation=None): + self.visit_Name(node) + return dnode + + def visit_MatchOr(self, node): + dnode = self.chains.setdefault(node, Def(node)) + for pat in node.patterns: + self.visit(pat).add_user(dnode) + return dnode + + # expressions + def visit_BoolOp(self, node): dnode = self.chains.setdefault(node, Def(node)) for value in node.values: @@ -1112,7 +1391,7 @@ def visit_Subscript(self, node): return dnode def visit_Starred(self, node): - if isinstance(node.ctx, ast.Store): + if type(node.ctx).__name__ == 'Store': return self.visit(node.value) else: dnode = self.chains.setdefault(node, Def(node)) @@ -1122,25 +1401,21 @@ def visit_Starred(self, node): def visit_NamedExpr(self, node): dnode = self.chains.setdefault(node, Def(node)) self.visit(node.value).add_user(dnode) - if isinstance(node.target, ast.Name): + if type(node.target).__name__ == 'Name': self.visit_Name(node.target, named_expr=True) return dnode - def is_in_current_scope(self, name): - return any(name in defs - for defs in self._definitions[self._scope_depths[-1]:]) - def _first_non_comprehension_scope(self): index = -1 enclosing_scope = self._scopes[index] - while isinstance(enclosing_scope, (ast.DictComp, ast.ListComp, - ast.SetComp, ast.GeneratorExp)): + while type(enclosing_scope).__name__ in _Comp: index -= 1 enclosing_scope = self._scopes[index] return index, enclosing_scope def visit_Name(self, node, skip_annotation=False, named_expr=False): - if isinstance(node.ctx, (ast.Param, ast.Store)): + ctx_typename = type(node.ctx).__name__ + if ctx_typename in ('Param', 'Store'): dnode = self.chains.setdefault(node, Def(node)) # FIXME: find a smart way to merge the code below with add_to_locals if any(node.id in _globals for _globals in self._globals): @@ -1151,7 +1426,7 @@ def visit_Name(self, node, skip_annotation=False, named_expr=False): index, enclosing_scope = (self._first_non_comprehension_scope() if named_expr else (-1, self._scopes[-1])) - if index < -1 and isinstance(enclosing_scope, ast.ClassDef): + if index < -1 and type(enclosing_scope).__name__ == 'ClassDef': # invalid named expression, not calling set_definition. self.warn('assignment expression within a comprehension ' 'cannot be used in a class body', node) @@ -1162,11 +1437,10 @@ def visit_Name(self, node, skip_annotation=False, named_expr=False): self.locals[self._scopes[index]].append(dnode) # Name.annotation is a special case because of gast - if node.annotation is not None and not skip_annotation and not self.future_annotations: + if getattr(node, 'annotation', None) is not None and not skip_annotation and not self.future_annotations: self.visit(node.annotation) - - elif isinstance(node.ctx, (ast.Load, ast.Del)): + elif ctx_typename in ('Load', 'Del'): node_in_chains = node in self.chains if node_in_chains: dnode = self.chains[node] @@ -1185,26 +1459,29 @@ def visit_Destructured(self, node): dnode = self.chains.setdefault(node, Def(node)) tmp_store = ast.Store() for elt in node.elts: - if isinstance(elt, ast.Name): + elt_typename = type(elt).__name__ + if elt_typename == 'Name': tmp_store, elt.ctx = elt.ctx, tmp_store self.visit(elt) tmp_store, elt.ctx = elt.ctx, tmp_store - elif isinstance(elt, (ast.Subscript, ast.Starred, ast.Attribute)): + elif elt_typename in ('Subscript', 'Starred', 'Attribute'): self.visit(elt) - elif isinstance(elt, (ast.List, ast.Tuple)): + elif elt_typename in ('List', 'Tuple'): self.visit_Destructured(elt) return dnode def visit_List(self, node): - if isinstance(node.ctx, ast.Load): + if type(node.ctx).__name__ == 'Load': dnode = self.chains.setdefault(node, Def(node)) for elt in node.elts: self.visit(elt).add_user(dnode) return dnode # unfortunately, destructured node are marked as Load, # only the parent List/Tuple is marked as Store - elif isinstance(node.ctx, ast.Store): + elif type(node.ctx).__name__ == 'Store': return self.visit_Destructured(node) + else: + raise NotImplementedError() visit_Tuple = visit_List @@ -1219,6 +1496,49 @@ def visit_Slice(self, node): if node.step: self.visit(node.step).add_user(dnode) return dnode + + # type params + + def visit_def695(self, node): + # We don't use two steps here because the declaration + # step is the same as definition step for def695's + # 1.type parameters of generic type aliases, + # 2.type parameters and annotations of generic functions and + # 3.type parameters and base class expressions of generic classes + # the rest is evaluated as defered annotations: + # 4.the value of generic type aliases + # 5.the bounds of type variables + # 6.the constraints of type variables + + # introduce the new scope + dnode = self.chains.setdefault(node.d, Def(node.d)) + + with self.ScopeContext(node): + # visit the type params + for p in node.body: + try: + _validate_annotation_body(p) + except SyntaxError as e: + self.warn(str(e), p) + else: + self.visit(p).add_user(dnode) + # then visit the actual node while + # being in the def695 scope. + visitor = getattr(self, "visit_{}".format(type(node.d).__name__)) + visitor(node.d, in_def695=True) + + def visit_TypeVar(self, node): + dnode = self.chains.setdefault(node, Def(node)) + self.set_definition(node.name, dnode) + self.add_to_locals(node.name, dnode) + + if type(node).__name__ == 'TypeVar' and node.bound: + self._defered_annotations[-1].append( + (node.bound, list(self._scopes), None)) + + return dnode + + visit_ParamSpec = visit_TypeVarTuple = visit_TypeVar # misc @@ -1248,9 +1568,7 @@ def visit_excepthandler(self, node): self.process_body(node.body) return dnode - def visit_arguments(self, node): - for arg in _iter_arguments(node): - self.visit(arg) + # visit_arguments is not implemented on purpose def visit_withitem(self, node): dnode = self.chains.setdefault(node, Def(node)) @@ -1267,24 +1585,43 @@ def _validate_comprehension(node): """ iter_names = set() # comprehension iteration variables for gen in node.generators: - for namedexpr in (n for n in ast.walk(gen.iter) if isinstance(n, ast.NamedExpr)): + for namedexpr in (n for n in ast.walk(gen.iter) if type(n).__name__ == 'NamedExpr'): raise SyntaxError('assignment expression cannot be used ' 'in a comprehension iterable expression') iter_names.update(n.id for n in ast.walk(gen.target) - if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Store)) - for namedexpr in (n for n in ast.walk(node) if isinstance(n, ast.NamedExpr)): + if type(n).__name__ == 'Name' and type(n.ctx).__name__ == 'Store') + for namedexpr in (n for n in ast.walk(node) if type(n).__name__ == 'NamedExpr'): bound = getattr(namedexpr.target, 'id', None) if bound in iter_names: raise SyntaxError('assignment expression cannot rebind ' "comprehension iteration variable '{}'".format(bound)) +_node_type_to_human_name = { + 'NamedExpr': 'assignment expression', + 'Yield': 'yield keyword', + 'YieldFrom': 'yield keyword', + 'Await': 'await keyword' +} + +def _validate_annotation_body(node): + """ + Raises SyntaxError if: + - the warlus operator is used + - the yield/ yield from statement is used + - the await keyword is used + """ + for illegal in (n for n in ast.walk(node) if type(n).__name__ in + ('NamedExpr', 'Yield', 'YieldFrom', 'Await')): + name = _node_type_to_human_name.get(type(illegal).__name__, 'current syntax') + raise SyntaxError(f'{name} cannot be used in annotation-like scopes') + def _iter_arguments(args): """ Yields all arguments of the given ast.arguments instance. """ for arg in args.args: yield arg - for arg in args.posonlyargs: + for arg in getattr(args, 'posonlyargs', ()): yield arg if args.vararg: yield args.vararg @@ -1345,7 +1682,15 @@ def lookup_annotation_name_defs(name, heads, locals_map): try: return _lookup(name, scopes, locals_map) except LookupError: - raise LookupError("'{}' not found in {}, might be a builtin".format(name, heads[-1])) + if name in BuiltinsSrc: + raise LookupError(f'{name} is a builtin') + try: + _lookup(name, scopes, locals_map, only_live=False) + except LookupError: + defined_names = [d.name() for s in scopes for d in locals_map[s]] + raise LookupError("'{}' not found in scopes: {} (heads={}) (available names={})".format(name, scopes, heads, defined_names)) + else: + raise LookupError("'{}' is killed".format(name)) def _get_lookup_scopes(heads): # heads[-1] is the direct enclosing scope and heads[0] is the module. @@ -1355,33 +1700,33 @@ def _get_lookup_scopes(heads): heads = list(heads) # avoid modifying the list (important) try: - direct_scope = heads.pop(-1) # this scope is the only one that can be a class + direct_scopes = [heads.pop(-1)] # this scope is the only one that can be a class, expect in case of the presence of def695 except IndexError: raise ValueError('invalid heads: must include at least one element') try: global_scope = heads.pop(0) except IndexError: # we got only a global scope - return [direct_scope] + return direct_scopes + else: + if heads and isinstance(direct_scopes[-1], def695) and type(heads[-1]).__name__ == 'ClassDef': + direct_scopes.insert(0, heads.pop(-1)) # more of less modeling what's described here. # https://github.com/gvanrossum/gvanrossum.github.io/blob/main/formal/scopesblog.md - other_scopes = [s for s in heads if isinstance(s, ( - ast.FunctionDef, ast.AsyncFunctionDef, - ast.Lambda, ast.DictComp, ast.ListComp, - ast.SetComp, ast.GeneratorExp))] - return [global_scope] + other_scopes + [direct_scope] - -def _lookup(name, scopes, locals_map): - context = scopes.pop() + other_scopes = [s for s in heads if type(s).__name__ in _ClosedScopes] + return [global_scope] + other_scopes + direct_scopes + +def _lookup(name, scopes, locals_map, only_live=True): + context = scopes[-1] defs = [] for loc in locals_map.get(context, ()): - if loc.name() == name and loc.islive: + if loc.name() == name and (loc.islive if only_live else True): defs.append(loc) if defs: return defs - elif len(scopes)==0: + elif len(scopes)==1: raise LookupError() - return _lookup(name, scopes, locals_map) + return _lookup(name, scopes[:-1], locals_map) class UseDefChains(object): """ @@ -1394,7 +1739,7 @@ class UseDefChains(object): def __init__(self, defuses): self.chains = {} for chain in defuses.chains.values(): - if isinstance(chain.node, ast.Name): + if type(chain.node).__name__ == 'Name': self.chains.setdefault(chain.node, []) for use in chain.users(): self.chains.setdefault(use.node, []).append(chain) diff --git a/beniget/standard.py b/beniget/standard.py new file mode 100644 index 0000000..81b480b --- /dev/null +++ b/beniget/standard.py @@ -0,0 +1,61 @@ +""" +This module offers the same three analyses, but designed to be run on standard library nodes. +""" + +import ast +from beniget import beniget, Ancestors + +__all__ = ('Ancestors', 'Def', 'DefUseChains', 'UseDefChains') + +class DefUseChains(beniget.DefUseChains): + + def visit_skip_annotation(self, node): + if isinstance(node, ast.arg): + return self.visit_arg(node, skip_annotation=True) + return super().visit_skip_annotation(node) + + def visit_ExceptHandler(self, node): + if isinstance(node.name, str): + # standard library nodes does not wrap + # the exception 'as' name in Name instance, so we use + # the ExceptHandler instance as reference point. + dnode = self.chains.setdefault(node, beniget.Def(node)) + self.set_definition(node.name, dnode) + if dnode not in self.locals[self._scopes[-1]]: + self.locals[self._scopes[-1]].append(dnode) + self.generic_visit(node) + + def visit_arg(self, node, skip_annotation=False): + dnode = self.chains.setdefault(node, beniget.Def(node)) + self.set_definition(node.arg, dnode) + if dnode not in self.locals[self._scopes[-1]]: + self.locals[self._scopes[-1]].append(dnode) + if node.annotation is not None and not skip_annotation: + self.visit(node.annotation) + return dnode + + def visit_ExtSlice(self, node): + dnode = self.chains.setdefault(node, beniget.Def(node)) + for elt in node.dims: + self.visit(elt).add_user(dnode) + return dnode + + def visit_Index(self, node): + # pretend Index does not exist + return self.visit(node.value) + + visit_NameConstant = visit_Num = visit_Str = \ + visit_Bytes = visit_Ellipsis = visit_Constant = beniget.DefUseChains.visit_Constant + +class UseDefChains(beniget.UseDefChains): + def __init__(self, defuses): + self.chains = {} + for chain in defuses.chains.values(): + if isinstance(chain.node, (ast.Name, ast.arg)): # the only change is here + self.chains.setdefault(chain.node, []) + for use in chain.users(): + self.chains.setdefault(use.node, []).append(chain) + + for chain in defuses._builtins.values(): + for use in chain.users(): + self.chains.setdefault(use.node, []).append(chain) diff --git a/requirements.txt b/requirements.txt index 6c43616..d02c8e1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -gast ~= 0.5.0 +gast ~= 0.5.0 \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py index f1b4217..a0cf49a 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,6 @@ +""" +Each TestCase subclass should have it's standard library counterpart. +""" import tests.test_definitions import tests.test_chains import tests.test_capture diff --git a/tests/test_attributes.py b/tests/test_attributes.py index 3aa7eb3..be079f5 100644 --- a/tests/test_attributes.py +++ b/tests/test_attributes.py @@ -1,19 +1,20 @@ from unittest import TestCase from textwrap import dedent -import gast as ast -import beniget +import ast as _ast +import gast as _gast +from .test_chains import getDefUseChainsType -class Attributes(ast.NodeVisitor): +class Attributes(_ast.NodeVisitor): def __init__(self, module_node): - self.chains = beniget.DefUseChains() + self.chains = getDefUseChainsType(module_node)() self.chains.visit(module_node) self.attributes = set() self.users = set() def visit_ClassDef(self, node): for stmt in node.body: - if isinstance(stmt, ast.FunctionDef): + if isinstance(stmt, (_ast.FunctionDef, _gast.FunctionDef)): self_def = self.chains.chains[stmt.args.args[0]] self.users.update(use.node for use in self_def.users()) self.generic_visit(node) @@ -24,8 +25,9 @@ def visit_Attribute(self, node): class TestAttributes(TestCase): + ast = _gast def checkAttribute(self, code, extract, ref): - module = ast.parse(dedent(code)) + module = self.ast.parse(dedent(code)) c = Attributes(module) c.visit(extract(module)) self.assertEqual(c.attributes, ref) @@ -97,3 +99,6 @@ def bar(self, other): self = list return self.pop""" self.checkAttribute(code, lambda n: n.body[0], set()) + +class TestAttributesStdlib(TestAttributes): + ast = _ast \ No newline at end of file diff --git a/tests/test_capture.py b/tests/test_capture.py index ccd472a..c676a90 100644 --- a/tests/test_capture.py +++ b/tests/test_capture.py @@ -1,12 +1,14 @@ from unittest import TestCase from textwrap import dedent -import gast as ast -import beniget +import ast as _ast +import gast as _gast +from .test_chains import getDefUseChainsType -class Capture(ast.NodeVisitor): + +class Capture(_ast.NodeVisitor): def __init__(self, module_node): - self.chains = beniget.DefUseChains() + self.chains = getDefUseChainsType(module_node)() self.chains.visit(module_node) self.users = set() self.captured = set() @@ -17,15 +19,16 @@ def visit_FunctionDef(self, node): self.generic_visit(node) def visit_Name(self, node): - if isinstance(node.ctx, ast.Load): + if isinstance(node.ctx, (_ast.Load, _gast.Load)): if node not in self.users: # FIXME: IRL, should be the definition of this use self.captured.add(node.id) class TestCapture(TestCase): + ast = _gast def checkCapture(self, code, extract, ref): - module = ast.parse(dedent(code)) + module = self.ast.parse(dedent(code)) c = Capture(module) c.visit(extract(module)) self.assertEqual(c.captured, ref) @@ -43,3 +46,6 @@ def foo(x): def bar(x): return x""" self.checkCapture(code, lambda n: n.body[0].body[0], set()) + +class TestCaptureStdlib(TestCapture): + ast = _ast \ No newline at end of file diff --git a/tests/test_chains.py b/tests/test_chains.py index a388867..08633a7 100644 --- a/tests/test_chains.py +++ b/tests/test_chains.py @@ -1,14 +1,27 @@ from contextlib import contextmanager from unittest import TestCase, skipIf import unittest -import gast as ast -import beniget +import beniget.standard import io import sys +import ast as _ast +import gast as _gast +import gast.gast as _gast_module # Show full diff in unittest unittest.util._MAX_LENGTH=2000 +def replace_deprecated_names(out): + return out.replace( + 'Num', 'Constant' + ).replace( + 'Ellipsis', 'Constant' + ).replace( + 'Str', 'Constant' + ).replace( + 'Bytes', 'Constant' + ) + @contextmanager def captured_output(): new_out, new_err = io.StringIO(), io.StringIO() @@ -19,25 +32,53 @@ def captured_output(): finally: sys.stdout, sys.stderr = old_out, old_err +gast_nodes = tuple(getattr(_gast, t[0]) for t in _gast_module._nodes) -class TestDefUseChains(TestCase): - def checkChains(self, code, ref, strict=True): - class StrictDefUseChains(beniget.DefUseChains): +def getDefUseChainsType(node): + if isinstance(node, gast_nodes): + return beniget.DefUseChains + return beniget.standard.DefUseChains + +def getStrictDefUseChains(node): + class StrictDefUseChains(getDefUseChainsType(node)): def warn(self, msg, node): raise RuntimeError( "W: {} at {}:{}".format( msg, node.lineno, node.col_offset ) ) + return StrictDefUseChains + +def getUseDefChainsType(node): + if isinstance(node, gast_nodes): + return beniget.UseDefChains + return beniget.standard.UseDefChains - node = ast.parse(code) +class TestDefUseChains(TestCase): + ast = _gast + maxDiff = None + def checkChains(self, code, ref, strict=True): + node = self.ast.parse(code) if strict: - c = StrictDefUseChains() + c = getStrictDefUseChains(node)() else: - c = beniget.DefUseChains() + c = getDefUseChainsType(node)() + c.visit(node) - self.assertEqual(c.dump_chains(node), ref) + self.assertEqual(c.dump_chains(node), ref) return node, c + + def checkUseDefChains(self, code, ref, strict=True): + node = self.ast.parse(code) + if strict: + c = getStrictDefUseChains(node)() + else: + c = getDefUseChainsType(node)() + + c.visit(node) + cc = getUseDefChainsType(node)(c) + + self.assertEqual(str(cc), ref) def test_simple_expression(self): code = "a = 1; a + 2" @@ -398,8 +439,8 @@ def test_class_base(self): def test_def_used_in_self_default(self): code = "def foo(x:foo): return foo" - c = beniget.DefUseChains() - node = ast.parse(code) + node = self.ast.parse(code) + c = getDefUseChainsType(node)() c.visit(node) self.assertEqual(c.dump_chains(node), ["foo -> (foo -> ())"]) @@ -411,36 +452,36 @@ class mytype(str): x = x+1 # <- this triggers NameError: name 'x' is not defined return x ''' - c = beniget.DefUseChains() - node = ast.parse(code) + node = self.ast.parse(code) + c = getDefUseChainsType(node)() c.visit(node) self.assertEqual(c.dump_chains(node.body[0]), ['x -> (x -> ())', 'mytype -> ()']) def test_unbound_class_variable2(self): code = '''class A:\n a = 10\n def f(self):\n return a # a is not defined''' - c = beniget.DefUseChains() - node = ast.parse(code) + node = self.ast.parse(code) + c = getDefUseChainsType(node)() c.visit(node) self.assertEqual(c.dump_chains(node.body[0]), ['a -> ()', 'f -> ()']) def test_unbound_class_variable3(self): code = '''class A:\n a = 10\n class I:\n b = a + 1 # a is not defined''' - c = beniget.DefUseChains() - node = ast.parse(code) + node = self.ast.parse(code) + c = getDefUseChainsType(node)() c.visit(node) self.assertEqual(c.dump_chains(node.body[0]), ['a -> ()', 'I -> ()']) def test_unbound_class_variable4(self): code = '''class A:\n a = 10\n f = lambda: a # a is not defined''' - c = beniget.DefUseChains() - node = ast.parse(code) + node = self.ast.parse(code) + c = getDefUseChainsType(node)() c.visit(node) self.assertEqual(c.dump_chains(node.body[0]), ['a -> ()', 'f -> ()']) def test_unbound_class_variable5(self): code = '''class A:\n a = 10\n b = [a for _ in range(10)] # a is not defined''' - c = beniget.DefUseChains() - node = ast.parse(code) + node = self.ast.parse(code) + c = getDefUseChainsType(node)() c.visit(node) self.assertEqual(c.dump_chains(node.body[0]), ['a -> ()', 'b -> ()']) @@ -463,14 +504,14 @@ def count(self) -> mytype: # this should trigger unbound identifier def c(x) -> mytype(): # this one shouldn't ... ''' - c = beniget.DefUseChains() - node = ast.parse(code) + node = self.ast.parse(code) + c = getDefUseChainsType(node)() c.visit(node) self.assertEqual(c.dump_chains(node.body[0].body[0]), ['mytype -> (mytype -> (Call -> ()))']) def check_message(self, code, expected_messages, filename=None): - node = ast.parse(code) - c = beniget.DefUseChains(filename) + node = self.ast.parse(code) + c = getDefUseChainsType(node)(filename) with captured_output() as (out, err): c.visit(node) @@ -524,8 +565,8 @@ class Visitor: def visit_Name(self, node):pass visit_Attribute = visit_Name ''' - node = ast.parse(code) - c = beniget.DefUseChains() + node = self.ast.parse(code) + c = getDefUseChainsType(node)() c.visit(node) self.assertEqual(c.dump_chains(node), ['visit_Name -> ()', 'Visitor -> ()']) @@ -540,8 +581,8 @@ class Attr(object):pass class Visitor: class Attr(Attr):pass ''' - node = ast.parse(code) - c = beniget.DefUseChains() + node = self.ast.parse(code) + c = getDefUseChainsType(node)() c.visit(node) self.assertEqual(c.dump_chains(node), ['Attr -> (Attr -> (Attr -> ()))', @@ -593,8 +634,8 @@ class Visitor: def f(): return f() def visit_Name(self, node:Thing, fn:f):... ''' - node = ast.parse(code) - c = beniget.DefUseChains() + node = self.ast.parse(code) + c = getDefUseChainsType(node)() c.visit(node) self.assertEqual(c.dump_chains(node), ['Thing -> ()', @@ -611,8 +652,8 @@ def visit_Attribute(self, node):pass class Visitor: visit_Attribute = visit_Attribute ''' - node = ast.parse(code) - c = beniget.DefUseChains() + node = self.ast.parse(code) + c = getDefUseChainsType(node)() c.visit(node) self.assertEqual(c.dump_chains(node), ['visit_Attribute -> (visit_Attribute -> ())', @@ -919,7 +960,7 @@ def Thing(self, y:Type[Thing]) -> Thing: # this is OK, and it links to the top l "W: unbound identifier 'D'", ] - assert len(produced_messages) == len(expected_warnings), len(produced_messages) + assert len(produced_messages) == len(expected_warnings), produced_messages assert all(any(w in pw for pw in produced_messages) for w in expected_warnings) # locals of C @@ -1064,8 +1105,8 @@ class mytype2(int): fn = outer() ''' - mod = ast.parse(code) - chains = beniget.DefUseChains('test') + mod = self.ast.parse(code) + chains = getDefUseChainsType(mod)('test') with captured_output() as (out, err): chains.visit(mod) @@ -1110,8 +1151,8 @@ class mytype2(int): # to the inner classes. def test_lookup_scopes(self): - from beniget.beniget import _get_lookup_scopes - mod, fn, cls, lambd, gen, comp = ast.Module(), ast.FunctionDef(), ast.ClassDef(), ast.Lambda(), ast.GeneratorExp(), ast.DictComp() + from beniget.beniget import _get_lookup_scopes, def695 + mod, fn, cls, lambd, gen, comp, typeparams = self.ast.Module(), self.ast.FunctionDef(), self.ast.ClassDef(), self.ast.Lambda(), self.ast.GeneratorExp(), self.ast.DictComp(), def695(body=[], d=self.ast.FunctionDef()) assert _get_lookup_scopes((mod, fn, fn, fn, cls)) == [mod, fn, fn, fn, cls] assert _get_lookup_scopes((mod, fn, fn, fn, cls, fn)) == [mod, fn, fn, fn, fn] assert _get_lookup_scopes((mod, cls, fn)) == [mod, fn] @@ -1121,6 +1162,11 @@ def test_lookup_scopes(self): assert _get_lookup_scopes((mod, fn)) == [mod, fn] assert _get_lookup_scopes((mod, cls)) == [mod, cls] assert _get_lookup_scopes((mod,)) == [mod] + assert _get_lookup_scopes((mod, typeparams)) == [mod, typeparams] + assert _get_lookup_scopes((mod, typeparams, typeparams)) == [mod, typeparams, typeparams] + assert _get_lookup_scopes((mod, cls, typeparams)) == [mod, cls, typeparams] + assert _get_lookup_scopes((mod, cls, cls, typeparams)) == [mod, cls, typeparams] + assert _get_lookup_scopes((mod, cls, cls, typeparams, fn)) == [mod, typeparams, fn] with self.assertRaises(ValueError, msg='invalid heads: must include at least one element'): _get_lookup_scopes(()) @@ -1237,28 +1283,500 @@ class A: strict=False ) - @skipIf(sys.version_info.major < 3, "Python 3 syntax") def test_annotation_def_is_not_assign_target(self): code = 'from typing import Optional; var:Optional' self.checkChains(code, ['Optional -> (Optional -> ())', 'var -> ()']) + + def test_pep563_disallowed_expressions(self): + cases = [ + "def func(a: (yield)) -> ...: ...", + "def func(a: ...) -> (yield from []): ...", + "def func(*a: (y := 3)) -> ...: ...", + "def func(**a: (await 42)) -> ...: ...", + + "x: (yield) = True", + "x: (yield from []) = True", + "x: (y := 3) = True", + "x: (await 42) = True",] + + for code in cases: + code = f'from __future__ import annotations\n' + code + with self.subTest(code): + self.check_message(code, ['cannot be used in annotation-like scopes']) -class TestUseDefChains(TestCase): - def checkChains(self, code, ref): - class StrictDefUseChains(beniget.DefUseChains): - def unbound_identifier(self, name, node): - raise RuntimeError( - "W: unbound identifier '{}' at {}:{}".format( - name, node.lineno, node.col_offset - ) - ) + for code in cases: + with self.subTest(code): + # From python 3.13, this should generate the same error. + self.check_message(code, []) + + # PEP-695 test cases taken from https://github.com/python/cpython/pull/103764/files + # but also https://github.com/python/cpython/pull/109297/files and + # https://github.com/python/cpython/pull/109123/files + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_typeparams_name_collision_01(self): + # The following code triggers syntax error at runtime. + # But detecting this would required beniget to keep track of the + # names of type parameters and validate them like we validate comprehensions or annotations. + # We don't do it for functions currently, so it doesn't make sens to do it for + # type parameters at this time. + code = """def func[**A, A](): ...""" + self.checkChains(code, ['func -> ()']) + self.checkUseDefChains(code, 'func <- {A, A}') + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_typeparams_name_non_collision_02(self): + code = """def func[A](A): return A""" + self.checkChains(code, ['func -> ()']) + self.checkUseDefChains(code, 'A <- {A}, A <- {}, func <- {A}') + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_typeparams_name_non_collision_03(self): + code = """def func[A](*A): return A""" + self.checkChains(code, ['func -> ()']) + self.checkUseDefChains(code, 'A <- {A}, A <- {}, func <- {A}') + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_typeparams_name_non_collision_04(self): + # Mangled names should not cause a conflict. + code = """class ClassA:\n def func[__A](self, __A): return __A""" + self.checkChains(code, ['ClassA -> ()']) + self.checkUseDefChains(code, '__A <- {__A}, __A <- {}, func <- {__A}, self <- {}') + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_typeparams_name_non_collision_05(self): + code = """class ClassA:\n def func[_ClassA__A](self, __A): return __A""" + self.checkChains(code, ['ClassA -> ()']) + self.checkUseDefChains(code, '__A <- {__A}, __A <- {}, func <- {_ClassA__A}, self <- {}') + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_typeparams_name_non_collision_06(self): + code = """class ClassA[X]:\n def func(self, X): return X""" + self.checkChains(code, ['ClassA -> ()']) + self.checkUseDefChains(code, 'ClassA <- {X}, X <- {X}, X <- {}, self <- {}') + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_typeparams_name_non_collision_07(self): + code = """class ClassA[X]:\n def func(self):\n X = 1;return X""" + self.checkChains(code, ['ClassA -> ()']) + self.checkUseDefChains(code, 'ClassA <- {X}, X <- {X}, X <- {}, self <- {}') + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_typeparams_name_non_collision_08(self): + code = """class ClassA[X]:\n def func(self): return [X for X in [1, 2]]""" + self.checkChains(code, ['ClassA -> ()']) + self.checkUseDefChains(code, 'ClassA <- {X}, List <- {Constant, Constant}, ListComp <- {X, comprehension}, X <- {X}, X <- {}, comprehension <- {List}, self <- {}') + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_typeparams_name_non_collision_09(self): + code = """class ClassA[X]:\n def func[X](self):...""" + self.checkChains(code, ['ClassA -> ()']) + self.checkUseDefChains(code, 'ClassA <- {X}, func <- {X}, self <- {}') + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_typeparams_name_non_collision_10(self): + code = """class ClassA[X]:\n X: int""" + self.checkChains(code, ['ClassA -> ()']) + self.checkUseDefChains(code, 'ClassA <- {X}, X <- {}, int <- {type}') + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_typeparams_name_non_collision_13(self): + code = """X = 1\ndef outer():\n def inner[X]():\n global X;X=2\n return inner""" + node, chains = self.checkChains(code, ['X -> ()', 'outer -> ()']) + self.assertEqual(chains.dump_chains(node.body[-1]), ['inner -> (inner -> ())']) + self.checkUseDefChains(code, 'X <- {}, X <- {}, inner <- {X}, inner <- {inner}') + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_typeparams_disallowed_expressions(self): + cases = ["type X = (yield)", + "type X = (yield from x)", + "type X = (await 42)", + "async def f(): type X = (yield)", + "type X = (y := 3)", + "class X[T: (yield)]: pass", + "class X[T: (yield from [])]: pass", + "class X[T: (await 42)]: pass", + "class X[T: (y := 3)]: pass", + "class X[T](y := list[T]): pass", + "def f[T](y: (x := list[T])): pass",] + + for code in cases: + with self.subTest(code): + self.check_message(code, ['cannot be used in annotation-like scopes']) + + for code in cases: + code = f'from __future__ import annotations\n' + code + with self.subTest(code): + self.check_message(code, ['cannot be used in annotation-like scopes']) + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_type_alias_name_collision_01(self): + # syntax error at runtime "duplicate type parameter 'A'" + code = """type TA1[A, **A] = None""" + self.checkChains(code, ['TA1 -> ()']) + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_type_alias_name_non_collision_02(self): + code = """type TA1[A] = lambda A: A""" + self.checkChains(code, ['TA1 -> ()']) + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_type_alias_name_non_collision_03(self): + code = """class Outer[A]:\n type TA1[A] = None""" + self.checkChains(code, ['Outer -> ()']) + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_type_alias_access_01(self): + code = "type TA1[A, B] = dict[A, B]" + self.checkChains(code, ['TA1 -> ()']) + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_type_alias_access_02(self): + code = """type TA1[A, B] = TA1[A, B] | int""" + self.checkChains(code, ['TA1 -> (TA1 -> (Subscript -> (BinOp -> ())))']) + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_type_alias_access_03(self): + code = """class Outer[A]:\n def inner[B](self):\n type TA1[C] = TA1[A, B] | int; return TA1""" + self.checkChains(code, ['Outer -> ()']) + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_scopes01(self): + code = """\ +from typing import Sequence + +# The following generates no compiler error, but a type checker +# should generate an error because an upper bound type must be concrete, +# and ``Sequence[S]`` is generic. Future extensions to the type system may +# eliminate this limitation. +class ClassA[S, T: Sequence[S]]: ... + +# The following generates no compiler error, because the bound for ``S`` +# is lazily evaluated. However, type checkers should generate an error. +class ClassB[S: Sequence[T], T]: ... +""" + self.checkChains(code, ['Sequence -> (Sequence -> (Subscript -> ()), Sequence -> (Subscript -> ()))', + 'ClassA -> ()', + 'ClassB -> ()']) + self.checkUseDefChains(code, 'ClassA <- {S, T}, ClassB <- {S, T}, S <- {S}, ' + 'Sequence <- {Sequence}, Sequence <- {Sequence}, ' + 'Subscript <- {S, Sequence}, Subscript <- {Sequence, T}, T <- {T}') - node = ast.parse(code) - c = StrictDefUseChains() - c.visit(node) - cc = beniget.UseDefChains(c) + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_scopes02(self): + code = """\ +from x import BaseClass, dec, Foo - self.assertEqual(str(cc), ref) +class ClassA[T](BaseClass[T], param = Foo[T]): ... # OK + +print(T) # Runtime error: 'T' is not defined + +@dec(Foo[T]) # Runtime error: 'T' is not defined +class ClassA[T]: ... +""" + self.check_message(code, ["W: unbound identifier 'T' at :5:6", + "W: unbound identifier 'T' at :7:9"]) + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_scopes03(self): + code = """\ +from x import dec +def func1[T](a: T) -> T: ... # OK + +print(T) # Runtime error: 'T' is not defined + +def func2[T](a = list[T]): ... # Runtime error: 'T' is not defined + +@dec(list[T]) # Runtime error: 'T' is not defined +def func3[T](): ... +""" + self.check_message(code, ["W: unbound identifier 'T' at :4:6", + "W: unbound identifier 'T' at :6:22", + "W: unbound identifier 'T' at :8:10"]) + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_scopes04(self): + + code = """\ +S = 0 + +def outer1[S](): + S = 1 + T = 1 + + def outer2[T](): + + def inner1(): + nonlocal S # OK because it binds variable S from outer1 + print(S) + nonlocal T # Syntax error: nonlocal binding not allowed for type parameter + print(T) + + def inner2(): + global S # OK because it binds variable S from global scope + print(S) +""" + self.check_message(code, ['W: names defined in annotation scopes cannot be rebound with nonlocal statements at :12:12']) + self.checkChains(code, ['S -> (S -> (Call -> ()))', 'outer1 -> ()'], strict=False) + self.checkUseDefChains(code, 'Call <- {S, print}, Call <- {S, print}, Call <- {T, print}, S <- {S}, S <- {S}, S <- {}, S <- {}, T <- {T}, T <- {}, outer1 <- {S}, outer2 <- {T}, print <- {builtin_function_or_method}, print <- {builtin_function_or_method}, print <- {builtin_function_or_method}', strict=False) + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_scopes04bis(self): + code = '''\ +def outer1(): + def outer2[T](): + def inner1(): + print(T) + inner1() + outer2() +outer1() +''' + self.check_message(code, []) + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_scopes05(self): + code = """\ +from typing import Sequence + +class Outer: + class Private: + pass + + # If the type parameter scope was like a traditional scope, + # the base class 'Private' would not be accessible here. + class Inner[T](Private, Sequence[T]): + pass + + # Likewise, 'Inner' would not be available in these type annotations. + def method1[T](self, a: Inner[T]) -> Inner[T]: + return a +""" + self.checkChains(code, ['Sequence -> (Sequence -> (Subscript -> (Inner -> (Inner -> (Subscript -> ()), Inner -> (Subscript -> ())))))', + 'Outer -> ()']) + self.checkUseDefChains(code, 'Inner <- {Inner}, Inner <- {Inner}, Inner <- {Private, Subscript, T}, Private <- {Private}, ' + 'Sequence <- {Sequence}, Subscript <- {Inner, T}, Subscript <- {Inner, T}, Subscript <- {Sequence, T}, ' + 'T <- {T}, T <- {T}, T <- {T}, a <- {a}, a <- {}, method1 <- {T}, self <- {}') + + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_scopes06(self): + code = """\ +from typing import Sequence +from x import decorator + +T = 0 + +@decorator(T) # Argument expression `T` evaluates to 0 +class ClassA[T](Sequence[T]): + T = 1 + + # All methods below should result in a type checker error + # "type parameter 'T' already in use" because they are using the + # type parameter 'T', which is already in use by the outer scope + # 'ClassA'. + def method1[T](self): + ... + + def method2[T](self, x = T): # Parameter 'x' gets default value of 1 + ... + + def method3[T](self, x: T): # Parameter 'x' has type T (scoped to method3) + ... + +""" + self.checkChains(code, ['Sequence -> (Sequence -> (Subscript -> (ClassA -> ())))', + 'decorator -> (decorator -> (Call -> (ClassA -> ())))', + 'T -> (T -> (Call -> (ClassA -> ())))', + 'ClassA -> ()']) + self.checkUseDefChains(code, 'Call <- {T, decorator}, ClassA <- {Call, Subscript, T}, ' + 'Sequence <- {Sequence}, Subscript <- {Sequence, T}, T <- {T}, T <- {T}, ' + 'T <- {T}, T <- {T}, T <- {}, T <- {}, decorator <- {decorator}, method1 <- {T}, ' + 'method2 <- {T, T}, method3 <- {T}, self <- {}, self <- {}, self <- {}, x <- {}, x <- {}') + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_scopes07(self): + code = """\ +T = 0 + +# T refers to the global variable +print(T) # Prints 0 + +class Outer[T]: + T = 1 + + # T refers to the local variable scoped to class 'Outer' + print(T) # Prints 1 + + class Inner1: + T = 2 + + # T refers to the local type variable within 'Inner1' + print(T) # Prints 2 + + def inner_method(self): + # T refers to the type parameter scoped to class 'Outer'; + # If 'Outer' did not use the new type parameter syntax, + # this would instead refer to the global variable 'T' + print(T) # Prints 'T' + + def outer_method(self): + T = 3 + + # T refers to the local variable within 'outer_method' + print(T) # Prints 3 + + def inner_func(): + # T refers to the variable captured from 'outer_method' + print(T) # Prints 3 +""" + self.checkChains(code, ['T -> (T -> (Call -> ()))', 'Outer -> ()']) + self.checkUseDefChains(code, 'Call <- {T, print}, Call <- {T, print}, Call <- {T, print}, Call <- {T, print}, ' + 'Call <- {T, print}, Call <- {T, print}, Outer <- {T}, T <- {T}, T <- {T}, T <- {T}, ' + 'T <- {T}, T <- {T}, T <- {T}, T <- {}, T <- {}, T <- {}, T <- {}, print <- {builtin_function_or_method}, ' + 'print <- {builtin_function_or_method}, print <- {builtin_function_or_method}, ' + 'print <- {builtin_function_or_method}, print <- {builtin_function_or_method}, ' + 'print <- {builtin_function_or_method}, self <- {}, self <- {}') + + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_scopes08(self): + code = '''\ +from x import decorator +T = 1 +@decorator +def f[decorator, T: int, U: (int, str), *Ts, **P]( + y: U, + x: T = T, # default values are evaluated outside the def695 scope + *args: *Ts, + **kwargs: P.kwargs, +) -> T: + return x +''' + self.checkChains(code, ['decorator -> (decorator -> ())', + 'T -> (T -> (f -> ()))', + 'f -> ()']) + self.checkUseDefChains(code, 'Attribute <- {P}, P <- {P}, Starred <- {Ts}, T <- {T}, T <- {T}, T <- {T}, T <- {}, ' + 'Ts <- {Ts}, Tuple <- {int, str}, U <- {U}, args <- {}, decorator <- {decorator}, ' + 'f <- {P, T, T, Ts, U, decorator}, int <- {type}, int <- {type}, kwargs <- {}, ' + 'str <- {type}, x <- {x}, x <- {}, y <- {}') + + @skipIf(sys.version_info < (3,12), "Python 3.12 syntax") + def test_pep695_scopes09(self): + code = '''\ +from x import decorator +@decorator +class B[decorator](object): + print(decorator) +''' + self.checkChains(code, ['decorator -> (decorator -> (B -> ()))', + 'B -> ()']) + + @skipIf(sys.version_info < (3,10), "Python 3.10 syntax") + def test_match_value(self): + code = ''' +command = 123 +match command: + case 123 as b: + b+=1 + ''' + self.checkChains(code, ['command -> (command -> ())', + 'b -> (b -> ())']) + + @skipIf(sys.version_info < (3,10), "Python 3.10 syntax") + def test_match_list(self): + code = ''' +command = 'go there' +match command.split(): + case ["go", direction]: + print(direction) + case _: + raise ValueError("Sorry") + ''' + self.checkChains(code, ['command -> (command -> (Attribute -> (Call -> ())))', + 'direction -> (MatchSequence -> (), direction -> (Call -> ()))']) + + @skipIf(sys.version_info < (3,10), "Python 3.10 syntax") + def test_match_list_star(self): + code = ''' +command = 'drop' +match command.split(): + case ["go", direction]: ... + case ["drop", *objects]: + print(objects) + ''' + self.checkChains(code, ['command -> (command -> (Attribute -> (Call -> ())))', + 'direction -> (MatchSequence -> ())', + 'objects -> (MatchSequence -> (), objects -> (Call -> ()))']) + + @skipIf(sys.version_info < (3,10), "Python 3.10 syntax") + def test_match_dict(self): + code = ''' +ui = object() +action = dict(text='') +match action: + case {"text": str(message), "color": str(c), **rest}: + ui.set_text_color(c) + ui.display(message) + print(rest) + case {"sleep": float(duration)}: + ui.wait(duration) + case {"sound": str(url), "format": "ogg"}: + ui.play(url) + case {"sound": _, "format": _}: + raise ValueError("Unsupported audio format") +print(c) + ''' + self.checkChains(code, ['ui -> (ui -> (Attribute -> (Call -> ())), ui -> (Attribute -> (Call -> ())), ui -> (Attribute -> (Call -> ())), ui -> (Attribute -> (Call -> ())))', + 'action -> (action -> ())', + 'message -> (MatchClass -> (rest -> (rest -> (Call -> ()))), message -> (Call -> ()))', + 'c -> (MatchClass -> (rest -> (rest -> (Call -> ()))), c -> (Call -> ()), c -> (Call -> ()))', + 'rest -> (rest -> (Call -> ()))', + 'duration -> (MatchClass -> (MatchMapping -> ()), duration -> (Call -> ()))', + 'url -> (MatchClass -> (MatchMapping -> ()), url -> (Call -> ()))']) + + @skipIf(sys.version_info < (3,10), "Python 3.10 syntax") + def test_match_class_rebinds_attrs(self): + + code = ''' +from dataclasses import dataclass + +@dataclass +class Point: + x: int + y: int + +point = Point(-2,1) +match point: + case Point(x=0, y=0): + print("Origin") + case Point(x=0, y=y): + print(f"Y={y}") + case Point(x=x, y=0): + print(f"X={x}") + case Point(x=x, y=y): + print("Somewhere else") + case _: + print("Not a point") +print(x, y) + ''' + self.checkChains( + code, ['dataclass -> (dataclass -> (Point -> (Point -> (Call -> ()), Point -> (MatchClass -> ()), Point -> (MatchClass -> ()), Point -> (MatchClass -> ()), Point -> (MatchClass -> ()))))', + 'Point -> (Point -> (Call -> ()), Point -> (MatchClass -> ()), Point -> (MatchClass -> ()), Point -> (MatchClass -> ()), Point -> (MatchClass -> ()))', + 'point -> (point -> ())', + 'y -> (MatchClass -> (), y -> (FormattedValue -> (JoinedStr -> (Call -> ()))), y -> (Call -> ()))', + 'x -> (MatchClass -> (), x -> (FormattedValue -> (JoinedStr -> (Call -> ()))), x -> (Call -> ()))', + 'x -> (MatchClass -> (), x -> (Call -> ()))', + 'y -> (MatchClass -> (), y -> (Call -> ()))']) + + +class TestDefUseChainsStdlib(TestDefUseChains): + ast = _ast + + +class TestUseDefChains(TestCase): + ast = _gast + checkChains = TestDefUseChains.checkUseDefChains def test_simple_expression(self): code = "a = 1; a" @@ -1267,3 +1785,7 @@ def test_simple_expression(self): def test_call(self): code = "from foo import bar; bar(1, 2)" self.checkChains(code, "Call <- {Constant, Constant, bar}, bar <- {bar}") + +class TestUseDefChainsStdlib(TestDefUseChains): + ast = _ast + diff --git a/tests/test_definitions.py b/tests/test_definitions.py index e9f956a..c5ee0c9 100644 --- a/tests/test_definitions.py +++ b/tests/test_definitions.py @@ -1,23 +1,16 @@ from textwrap import dedent from unittest import TestCase -import gast as ast -import beniget +import gast as _gast +import ast as _ast import sys - -class StrictDefUseChains(beniget.DefUseChains): - def unbound_identifier(self, name, node): - raise RuntimeError( - "W: unbound identifier '{}' at {}:{}".format( - name, node.lineno, node.col_offset - ) - ) - +from .test_chains import getStrictDefUseChains class TestGlobals(TestCase): + ast = _gast def checkGlobals(self, code, ref): - node = ast.parse(code) - c = StrictDefUseChains() + node = self.ast.parse(code) + c = getStrictDefUseChains(node)() c.visit(node) self.assertEqual(c.dump_definitions(node), ref) @@ -278,13 +271,16 @@ def testGlobalLambda(self): code = "lambda x: x" self.checkGlobals(code, []) +class TestGlobalsStdlib(TestGlobals): + ast = _ast class TestClasses(TestCase): + ast = _gast def checkClasses(self, code, ref): - node = ast.parse(code) - c = StrictDefUseChains() + node = self.ast.parse(code) + c = getStrictDefUseChains(node)() c.visit(node) - classes = [n for n in node.body if isinstance(n, ast.ClassDef)] + classes = [n for n in node.body if isinstance(n, self.ast.ClassDef)] assert len(classes) == 1, "only one top-level function per test case" cls = classes[0] self.assertEqual(c.dump_definitions(cls), ref) @@ -293,13 +289,16 @@ def test_class_method_assign(self): code = "class C:\n def foo(self):pass\n bar = foo" self.checkClasses(code, ["bar", "foo"]) +class TestClassesStdlib(TestClasses): + ast = _ast class TestLocals(TestCase): + ast = _gast def checkLocals(self, code, ref): - node = ast.parse(dedent(code)) - c = StrictDefUseChains() + node = self.ast.parse(dedent(code)) + c = getStrictDefUseChains(node)() c.visit(node) - functions = [n for n in node.body if isinstance(n, ast.FunctionDef)] + functions = [n for n in node.body if isinstance(n, self.ast.FunctionDef)] assert len(functions) == 1, "only one top-level function per test case" f = functions[0] self.assertEqual(c.dump_definitions(f), ref) @@ -394,15 +393,20 @@ def foo(a): else: b = a""" self.checkLocals(code, ["a", "b"]) +class TestLocalsStdlib(TestLocals): + ast = _ast + class TestDefIsLive(TestCase): + ast = _gast + def checkLocals(self, c, node, ref, only_live=False): self.assertEqual(sorted(c._dump_locals(node, only_live=only_live)), sorted(ref)) def checkLiveLocals(self, code, livelocals, locals): - node = ast.parse(dedent(code)) - c = StrictDefUseChains() + node = self.ast.parse(dedent(code)) + c = getStrictDefUseChains(node)() c.visit(node) self.checkLocals(c, node, locals) self.checkLocals(c, node, livelocals, only_live=True) @@ -552,3 +556,5 @@ def test_more_loops(self): self.checkLiveLocals(code, ['b:2,6', 'v:9,10,4', 'k:10,13'], ['b:2,6', 'v:9,10,4', 'k:10,13']) +class TestDefIsLiveStdlib(TestDefIsLive): + ast = _ast \ No newline at end of file diff --git a/tox.ini b/tox.ini index 3e9111b..6ae1ab3 100644 --- a/tox.ini +++ b/tox.ini @@ -2,5 +2,6 @@ envlist = py27,py36,py37,py38,py39,py310,py311,py312 [testenv] deps = + git+https://github.com/serge-sans-paille/gast.git pytest -commands=pytest beniget/ tests/ --doctest-modules \ No newline at end of file +commands=pytest beniget/ tests/ --doctest-modules