diff --git a/beniget/beniget.py b/beniget/beniget.py index c964c50..034842b 100644 --- a/beniget/beniget.py +++ b/beniget/beniget.py @@ -1145,7 +1145,13 @@ def visit_Subscript(self, node): self.visit(node.slice).add_user(dnode) return dnode - visit_Starred = visit_Await + def visit_Starred(self, node): + if isinstance(node.ctx, ast.Store): + return self.visit(node.value) + else: + dnode = self.chains.setdefault(node, Def(node)) + self.visit(node.value).add_user(dnode) + return dnode def visit_NamedExpr(self, node): dnode = self.chains.setdefault(node, Def(node)) @@ -1195,7 +1201,7 @@ def visit_Destructured(self, node): tmp_store, elt.ctx = elt.ctx, tmp_store self.visit(elt) tmp_store, elt.ctx = elt.ctx, tmp_store - elif isinstance(elt, ast.Subscript): + elif isinstance(elt, (ast.Subscript, ast.Starred)): self.visit(elt) elif isinstance(elt, (ast.List, ast.Tuple)): self.visit_Destructured(elt) diff --git a/tests/test_chains.py b/tests/test_chains.py index 288ccc6..6186d46 100644 --- a/tests/test_chains.py +++ b/tests/test_chains.py @@ -75,6 +75,11 @@ def test_type_destructuring_for(self): code = "for a, b in ((1,2), (3,4)): a" self.checkChains(code, ["a -> (a -> ())", "b -> ()"]) + if sys.version_info.major >= 3: + def test_type_destructuring_starred(self): + code = "a, *b = range(2); b" + self.checkChains(code, ['a -> ()', 'b -> (b -> ())']) + def test_assign_in_loop(self): code = "a = 2\nwhile 1: a = 1\na" self.checkChains(code, ["a -> (a -> ())", "a -> (a -> ())"]) diff --git a/tests/test_definitions.py b/tests/test_definitions.py index 17f80e2..3e432a2 100644 --- a/tests/test_definitions.py +++ b/tests/test_definitions.py @@ -89,6 +89,11 @@ def testGlobalDestructuring(self): code = "x, y = 1, 2" self.checkGlobals(code, ["x", "y"]) + if sys.version_info.major >= 3: + def testGlobalStarredDestructuring(self): + code = "x, *y = 1, [2]" + self.checkGlobals(code, ["x", "y"]) + def testGlobalAugAssign(self): code = "x = 1; x += 2" self.checkGlobals(code, ["x"]) @@ -369,6 +374,10 @@ def test_LocalNonLocalAfter(self): ) self.checkLocals(code, ["a", "bar"]) + def test_LocalDestructuring(self): + code = "def foo(x): y, *z = x" + self.checkLocals(code, ["x", "y", "z"]) + def test_LocalMadeGlobal(self): code = "def foo(): global a; a = 1" self.checkLocals(code, []) @@ -468,4 +477,4 @@ def test_AssignmentSimple(self): a = a + 1 a = a + 1 """ - self.checkLiveLocals(code, ["a:4"], ["a:2,3,4"]) \ No newline at end of file + self.checkLiveLocals(code, ["a:4"], ["a:2,3,4"])