diff --git a/beniget/beniget.py b/beniget/beniget.py index 50f035b..908ebf0 100644 --- a/beniget/beniget.py +++ b/beniget/beniget.py @@ -1,5 +1,5 @@ from collections import defaultdict, OrderedDict -from contextlib import contextmanager +from contextlib import contextmanager, suppress import sys import gast as ast @@ -1060,8 +1060,9 @@ def visit_ListComp(self, node): dnode = self.chains.setdefault(node, Def(node)) with self.CompScopeContext(node): - for comprehension in node.generators: - self.visit(comprehension).add_user(dnode) + for i, comprehension in enumerate(node.generators): + self.visit_comprehension(comprehension, + is_nested=i!=0).add_user(dnode) self.visit(node.elt).add_user(dnode) return dnode @@ -1072,8 +1073,9 @@ def visit_DictComp(self, node): dnode = self.chains.setdefault(node, Def(node)) with self.CompScopeContext(node): - for comprehension in node.generators: - self.visit(comprehension).add_user(dnode) + for i, comprehension in enumerate(node.generators): + self.visit_comprehension(comprehension, + is_nested=i!=0).add_user(dnode) self.visit(node.key).add_user(dnode) self.visit(node.value).add_user(dnode) @@ -1220,9 +1222,19 @@ def visit_Slice(self, node): # misc - def visit_comprehension(self, node): + def visit_comprehension(self, node, is_nested:bool): dnode = self.chains.setdefault(node, Def(node)) - self.visit(node.iter).add_user(dnode) + if not is_nested: + # There's one part of a comprehension or generator expression that executes in the surrounding scope, + # regardless of Python version: it's the expression for the outermost iterable. + scope_ctx = self.SwitchScopeContext(self._definitions[:-1], self._scopes[:-1], + self._scope_depths[:-1], self._precomputed_locals[:-1]) + else: + # If a comprehension has multiple for clauses, + # the iterables for inner for clauses are evaluated in the comprehension's scope: + scope_ctx = suppress() + with scope_ctx: + self.visit(node.iter).add_user(dnode) self.visit(node.target) for if_ in node.ifs: self.visit(if_).add_user(dnode) diff --git a/tests/test_chains.py b/tests/test_chains.py index e9d33f8..297bd00 100644 --- a/tests/test_chains.py +++ b/tests/test_chains.py @@ -597,6 +597,34 @@ def pop(): 'pop -> (pop -> (Call -> ()))', 'cos -> (cos -> (Call -> ()))' ]) + + def test_class_scope_comprehension(self): + code = ''' +class Cls: + foo = b'1', + [_ for _ in foo] + {_ for _ in foo} + (_ for _ in foo) + {_:1 for _ in foo} +''' + node, chains = self.checkChains(code, ['Cls -> ()']) + self.assertEqual(chains.dump_chains(node.body[0]), + ['foo -> (' + 'foo -> (comprehension -> (ListComp -> ())), ' + 'foo -> (comprehension -> (SetComp -> ())), ' + 'foo -> (comprehension -> (GeneratorExp -> ())), ' + 'foo -> (comprehension -> (DictComp -> ())))']) + + def test_class_scope_comprehension_invalid(self): + code = ''' +class Foo: + x = 5 + y = [x for i in range(1)] + z = [i for i in range(1) for j in range(x)] +''' + self.check_message(code, ["W: unbound identifier 'x' at test:4:9", + "W: unbound identifier 'x' at test:5:44"], 'test') + @skipIf(sys.version_info < (3, 8), 'Python 3.8 syntax') def test_named_expr_simple(self):