Skip to content

Commit

Permalink
Fix bug in Bug in class scope + comprehension # 65
Browse files Browse the repository at this point in the history
  • Loading branch information
tristanlatr committed Aug 7, 2023
1 parent c1c6ae3 commit 301b550
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 7 deletions.
26 changes: 19 additions & 7 deletions beniget/beniget.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict, OrderedDict
from contextlib import contextmanager
from contextlib import contextmanager, suppress
import sys

import gast as ast
Expand Down Expand Up @@ -1060,8 +1060,9 @@ def visit_ListComp(self, node):
dnode = self.chains.setdefault(node, Def(node))

with self.CompScopeContext(node):
for comprehension in node.generators:
self.visit(comprehension).add_user(dnode)
for i, comprehension in enumerate(node.generators):
self.visit_comprehension(comprehension,
is_nested=i!=0).add_user(dnode)
self.visit(node.elt).add_user(dnode)

return dnode
Expand All @@ -1072,8 +1073,9 @@ def visit_DictComp(self, node):
dnode = self.chains.setdefault(node, Def(node))

with self.CompScopeContext(node):
for comprehension in node.generators:
self.visit(comprehension).add_user(dnode)
for i, comprehension in enumerate(node.generators):
self.visit_comprehension(comprehension,
is_nested=i!=0).add_user(dnode)
self.visit(node.key).add_user(dnode)
self.visit(node.value).add_user(dnode)

Expand Down Expand Up @@ -1220,9 +1222,19 @@ def visit_Slice(self, node):

# misc

def visit_comprehension(self, node):
def visit_comprehension(self, node, is_nested:bool):
dnode = self.chains.setdefault(node, Def(node))
self.visit(node.iter).add_user(dnode)
if not is_nested:
# There's one part of a comprehension or generator expression that executes in the surrounding scope,
# regardless of Python version: it's the expression for the outermost iterable.
scope_ctx = self.SwitchScopeContext(self._definitions[:-1], self._scopes[:-1],
self._scope_depths[:-1], self._precomputed_locals[:-1])
else:
# If a comprehension has multiple for clauses,
# the iterables for inner for clauses are evaluated in the comprehension's scope:
scope_ctx = suppress()
with scope_ctx:
self.visit(node.iter).add_user(dnode)
self.visit(node.target)
for if_ in node.ifs:
self.visit(if_).add_user(dnode)
Expand Down
28 changes: 28 additions & 0 deletions tests/test_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,34 @@ def pop():
'pop -> (pop -> (Call -> ()))',
'cos -> (cos -> (Call -> ()))'
])

def test_class_scope_comprehension(self):
code = '''
class Cls:
foo = b'1',
[_ for _ in foo]
{_ for _ in foo}
(_ for _ in foo)
{_:1 for _ in foo}
'''
node, chains = self.checkChains(code, ['Cls -> ()'])
self.assertEqual(chains.dump_chains(node.body[0]),
['foo -> ('
'foo -> (comprehension -> (ListComp -> ())), '
'foo -> (comprehension -> (SetComp -> ())), '
'foo -> (comprehension -> (GeneratorExp -> ())), '
'foo -> (comprehension -> (DictComp -> ())))'])

def test_class_scope_comprehension_invalid(self):
code = '''
class Foo:
x = 5
y = [x for i in range(1)]
z = [i for i in range(1) for j in range(x)]
'''
self.check_message(code, ["W: unbound identifier 'x' at test:4:9",
"W: unbound identifier 'x' at test:5:44"], 'test')


@skipIf(sys.version_info < (3, 8), 'Python 3.8 syntax')
def test_named_expr_simple(self):
Expand Down

0 comments on commit 301b550

Please sign in to comment.