Skip to content

Commit

Permalink
Support the standard library via the dynamic pkg() function.
Browse files Browse the repository at this point in the history
  • Loading branch information
tristanlatr committed Sep 16, 2024
2 parents 0d8ba55 + 6f02909 commit 76afa83
Show file tree
Hide file tree
Showing 9 changed files with 240 additions and 101 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ beniget.egg-info

.pytest_cache
.vscode
.tox
build
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ It's a building block to write static analyzer or compiler for Python.
Beniget relies on `gast <https://pypi.org/project/gast/>`_ to provide a cross
version abstraction of the AST, effectively working across all Python 3 versions greater than 3.6.

Since version 0.5.0, beniget works with the standard library `ast <https://docs.python.org/3/library/ast.html#module-ast>`_ as well 🥳!

API
---

Expand Down
171 changes: 126 additions & 45 deletions beniget/beniget.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import defaultdict
from contextlib import contextmanager
import builtins
import sys

import gast
Expand All @@ -10,11 +11,11 @@
def pkg(node):
    """
    Return the module object in which the class of ``node`` is defined.

    The module name is taken from ``type(node).__module__`` and resolved
    through ``sys.modules``; a KeyError is raised if that module is not
    currently registered there.
    """
    defining_module = type(node).__module__
    return sys.modules[defining_module]


# NodeVisitor is compatible with standard library ast
class Ancestors(gast.NodeVisitor):
"""
Build the ancestor tree, that associates a node to the list of node visited
from the root node (the Module) to the current node
from the root node (the Module) to the current node.
Example usage with gast module
>>> from beniget import Ancestors
Expand All @@ -34,7 +35,7 @@ class Ancestors(gast.NodeVisitor):
<class 'gast.gast.Return'>
>>> import ast
>>> code = 'def foo(x): return x + 1'
>>> module = ast.parse(code)
>>> ancestors = Ancestors()
Expand Down Expand Up @@ -76,7 +77,7 @@ def parentFunction(self, node):
ast.AsyncFunctionDef))

def parentStmt(self, node):
return self.parentInstance(node, gast.stmt)
return self.parentInstance(node, gast.stmt) # gast.stmt and ast.stmt are the same.

_novalue = object()
@contextmanager
Expand All @@ -95,6 +96,9 @@ def _rename_attrs(obj, **attrs):
else:
setattr(obj, k, v)

_PY310PLUS = sys.version_info >= (3, 10)
_PY38PLUS = sys.version_info >= (3, 8)

class Def(object):
"""
Model a definition, either named or unnamed, and its users.
Expand Down Expand Up @@ -129,15 +133,19 @@ def name(self):
return self.node.name
elif isinstance(self.node, ast.Name):
return self.node.id
elif ast is not gast.gast and isinstance(self.node, ast.arg):
return self.node.arg
elif ast is not gast.gast and isinstance(self.node, ast.ExceptHandler):
return self.node.name
elif isinstance(self.node, ast.alias):
base = self.node.name.split(".", 1)[0]
return self.node.asname or base
elif isinstance(self.node, (ast.MatchStar, ast.MatchAs)):
if self.node.name:
return self.node.name
elif isinstance(self.node, ast.MatchMapping):
if self.node.rest:
return self.node.rest
elif _PY310PLUS and isinstance(self.node, (ast.MatchStar, ast.MatchAs)) \
and isinstance(self.node.name, str):
return self.node.name
elif _PY310PLUS and isinstance(self.node, ast.MatchMapping) \
and self.node.rest:
return self.node.rest
elif isinstance(self.node, ast.Attribute):
return "." + self.node.attr
elif isinstance(self.node, tuple):
Expand Down Expand Up @@ -182,7 +190,8 @@ def _str(self, nodes):
BuiltinsSrc = builtins.__dict__

Builtins = {k: v for k, v in BuiltinsSrc.items()}

# Not sure why we override the __file__ attribute.
# This should probably be set to the filename given to DefUseChains instead.
Builtins["__file__"] = __file__

DeclarationStep, DefinitionStep = object(), object()
Expand All @@ -208,7 +217,10 @@ class _CollectFutureImports(gast.NodeVisitor):
# - other future statements.
# as soon as we're visiting something else, we can stop the visit.
def __init__(self):
self.FutureImports = set() #type:set[str]
self.FutureImports = set()

# Compat
self.visit_Str = lambda v: None

def visit_Module(self, node):
for child in node.body:
Expand Down Expand Up @@ -281,6 +293,7 @@ def collect_locals(node):
visitor.generic_visit(node)
return visitor.Locals

_DeclarationStep, _DefinitionStep = object(), object()

class DefUseChains(gast.NodeVisitor):
"""
Expand All @@ -304,13 +317,13 @@ class DefUseChains(gast.NodeVisitor):
One instance of DefUseChains is only suitable to analyse one AST Module in its lifecycle.
"""


def __init__(self, filename=None):
"""
- filename: str, included in error messages if specified
"""
self.chains = {}
self.locals = defaultdict(list)

self.filename = filename

# deep copy of builtins, to remain reentrant
Expand Down Expand Up @@ -351,12 +364,12 @@ def __init__(self, filename=None):
# dead code levels, it's non null for code that cannot be executed
self._deadcode = 0

# attributes set in visit_Module
# attributes (re)set in visit_Module
self.module = None
self.future_annotations = False

#
## helpers
## test helpers
#
def _dump_locals(self, node, only_live=False):
"""
Expand All @@ -368,7 +381,8 @@ def _dump_locals(self, node, only_live=False):
for d in self.locals[node]:
if not only_live or d.islive:
groupped[d.name()].append(d)
return ['{}:{}'.format(name, ','.join([str(getattr(d.node, 'lineno', -1)) for d in defs])) \
# Compat: the linenumber is None on gast when unset in ast, so use None when unset.
return ['{}:{}'.format(name, ','.join([str(getattr(d.node, 'lineno', None)) for d in defs])) \
for name,defs in groupped.items()]

def dump_definitions(self, node, ignore_builtins=True):
Expand All @@ -392,7 +406,8 @@ def location(self, node):
)
return " at {}{}:{}".format(filename,
node.lineno,
node.col_offset)
# Compat: Not all stdlib nodes have a col_offset
getattr(node, 'col_offset', None))
else:
return ""

Expand Down Expand Up @@ -575,7 +590,7 @@ def process_functions_bodies(self):
visitor = getattr(self,
"visit_{}".format(type(fnode).__name__))
with self.SwitchScopeContext(defs, scopes, scope_depths, precomputed_locals):
visitor(fnode, step=DefinitionStep)
visitor(fnode, step=_DefinitionStep)

def process_annotations(self):
compute_defs, self.defs = self.defs, self.compute_annotation_defs
Expand All @@ -586,17 +601,76 @@ def process_annotations(self):
cb(visitor(annnode)) if cb else visitor(annnode)
self._scopes = currenthead
self.defs = compute_defs

def _support_stdlib(self):
    """
    Adjust this instance so that it can walk standard library ``ast`` nodes.

    Every visitor defined here is bound as an attribute of the instance, so
    only this instance is affected. ``visit_Module`` calls this method when
    the module it receives is not a ``gast.Module``.
    """
    # Keep a handle on the regular implementation *before* replacing it: the
    # wrapper below delegates to it for every node that is not an ``arg``.
    regular_skip_annotation = self.visit_skip_annotation

    def bind_in_current_scope(name, node):
        # Fetch (or create) the Def attached to ``node``, make it the
        # definition of ``name`` and record it once in the locals of the
        # innermost scope.
        definition = self.chains.setdefault(node, Def(node))
        self.set_definition(name, definition)
        scope_locals = self.locals[self._scopes[-1]]
        if definition not in scope_locals:
            scope_locals.append(definition)
        return definition

    def visit_skip_annotation(node):
        # ``arg`` nodes are routed to visit_arg with the annotation skipped;
        # everything else keeps the regular behaviour.
        if isinstance(node, pkg(node).arg):
            return self.visit_arg(node, skip_annotation=True)
        return regular_skip_annotation(node)

    def visit_arg(node, skip_annotation=False):
        definition = bind_in_current_scope(node.arg, node)
        annotation = node.annotation
        if annotation is not None and not skip_annotation:
            self.visit(annotation)
        return definition

    def visit_ExceptHandler(node):
        # In gast, the name field of ExceptHandler is a Name node with a
        # Store context rather than a str; for the standard library the
        # ExceptHandler node itself is used as the reference point.
        handler_name = node.name
        if isinstance(handler_name, str):
            bind_in_current_scope(handler_name, node)
        self.generic_visit(node)

    def visit_ExtSlice(node):
        # The ExtSlice definition is registered as a user of each dimension.
        definition = self.chains.setdefault(node, Def(node))
        for dim in node.dims:
            self.visit(dim).add_user(definition)
        return definition

    def visit_Index(node):
        # Index is treated as transparent: only the wrapped value is visited.
        return self.visit(node.value)

    self.visit_skip_annotation = visit_skip_annotation
    self.visit_arg = visit_arg
    self.visit_ExceptHandler = visit_ExceptHandler
    self.visit_ExtSlice = visit_ExtSlice
    self.visit_Index = visit_Index

    # The legacy constant node types all share the visit_Constant visitor.
    constant_visitor = self.visit_Constant
    for legacy_name in ('NameConstant', 'Num', 'Str', 'Bytes', 'Ellipsis'):
        setattr(self, 'visit_' + legacy_name, constant_visitor)

# stmt
def visit_Module(self, node):
# Compat
if not isinstance(node, gast.Module):
# If it's not a gast Module it must be a standard library module,
# so dynamically adjust the class to support it.
self._support_stdlib()

# save module node
self.module = node

futures = collect_future_imports(node)
# determine whether the PEP563 is enabled
# allow manual enabling of DefUseChains.future_annotations
self.future_annotations |= 'annotations' in futures


with self.ScopeContext(node):


Expand Down Expand Up @@ -646,7 +720,7 @@ def set_definition(self, name, dnode_or_dnodes, index=-1):

# set the islive flag to False on killed Defs
for d in self._definitions[index].get(name, ()):
if not isinstance(d.node, gast.AST):
if not isinstance(d.node, gast.AST): # gast.AST and ast.AST are the same.
# A builtin: we never explicitly mark the builtins as killed, since
# it can be easily deduced.
continue
Expand Down Expand Up @@ -706,8 +780,8 @@ def visit_skip_annotation(self, node):
else:
self.visit(node)

def visit_FunctionDef(self, node, step=DeclarationStep):
if step is DeclarationStep:
def visit_FunctionDef(self, node, step=_DeclarationStep):
if step is _DeclarationStep:
dnode = self.chains.setdefault(node, Def(node))
self.add_to_locals(node.name, dnode)

Expand Down Expand Up @@ -743,7 +817,7 @@ def visit_FunctionDef(self, node, step=DeclarationStep):
list(self._scopes),
list(self._scope_depths),
list(self._precomputed_locals)))
elif step is DefinitionStep:
elif step is _DefinitionStep:
with self.ScopeContext(node):
for arg in _iter_arguments(node.args):
self.visit_skip_annotation(arg)
Expand Down Expand Up @@ -1032,8 +1106,8 @@ def visit_UnaryOp(self, node):
self.visit(node.operand).add_user(dnode)
return dnode

def visit_Lambda(self, node, step=DeclarationStep):
if step is DeclarationStep:
def visit_Lambda(self, node, step=_DeclarationStep):
if step is _DeclarationStep:
dnode = self.chains.setdefault(node, Def(node))
for default in node.args.defaults:
self.visit(default).add_user(dnode)
Expand All @@ -1044,7 +1118,7 @@ def visit_Lambda(self, node, step=DeclarationStep):
list(self._scope_depths),
list(self._precomputed_locals)))
return dnode
elif step is DefinitionStep:
elif step is _DefinitionStep:
dnode = self.chains[node]
with self.ScopeContext(node):
for a in _iter_arguments(node.args):
Expand Down Expand Up @@ -1218,8 +1292,8 @@ def visit_Name(self, node, skip_annotation=False, named_expr=False):
if dnode not in self.locals[self._scopes[index]]:
self.locals[self._scopes[index]].append(dnode)

# Name.annotation is a special case because of gast
if node.annotation is not None and not skip_annotation and not self.future_annotations:
# Compat: Name.annotation is a special case because of gast
if getattr(node, 'annotation', None) is not None and not skip_annotation and not self.future_annotations:
self.visit(node.annotation)


Expand Down Expand Up @@ -1422,6 +1496,8 @@ def _validate_comprehension(node):
- a named expression is used in a comprehension iterable expression
- a named expression rebinds a comprehension iteration variable
"""
if not _PY38PLUS:
return
ast = pkg(node)
iter_names = set() # comprehension iteration variables
for gen in node.generators:
Expand All @@ -1436,21 +1512,6 @@ def _validate_comprehension(node):
raise SyntaxError('assignment expression cannot rebind '
"comprehension iteration variable '{}'".format(bound))

def _iter_arguments(args):
"""
Yields all arguments of the given ast.arguments instance.
"""
for arg in args.args:
yield arg
for arg in args.posonlyargs:
yield arg
if args.vararg:
yield args.vararg
for arg in args.kwonlyargs:
yield arg
if args.kwarg:
yield args.kwarg

def lookup_annotation_name_defs(name, heads, locals_map):
r"""
Simple identifier -> defs resolving.
Expand Down Expand Up @@ -1543,6 +1604,23 @@ def _lookup(name, scopes, locals_map):
raise LookupError()
return _lookup(name, scopes, locals_map)

def _iter_arguments(args):
"""
Yields all arguments of the given ast.arguments instance.
"""
for arg in args.args:
yield arg
# Compat: This method is used for stdlib nodes as well before 3.8.
# Should bengiget still support Python 3.7 ?
for arg in getattr(args, 'posonlyargs', ()):
yield arg
if args.vararg:
yield args.vararg
for arg in args.kwonlyargs:
yield arg
if args.kwarg:
yield args.kwarg

class UseDefChains(object):
"""
DefUseChains adaptor that builds a mapping between each user
Expand All @@ -1551,10 +1629,13 @@ class UseDefChains(object):
that define it.
"""

def __init__(self, defuses):
def __init__(self, defuses: DefUseChains):
self.chains = {}

# TODO: why doesn't this include functions and classes?
for chain in defuses.chains.values():
if isinstance(chain.node, pkg(chain.node).Name):
if isinstance(chain.node, pkg(chain.node).Name): # TODO: what about arguments ?
# They will be included for gast but not for stdlib ast, since args in gast are Name nodes.
self.chains.setdefault(chain.node, [])
for use in chain.users():
self.chains.setdefault(use.node, []).append(chain)
Expand Down
3 changes: 3 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""
Each TestCase subclass should have its standard library counterpart.
"""
import tests.test_definitions
import tests.test_chains
import tests.test_capture
Expand Down
Loading

0 comments on commit 76afa83

Please sign in to comment.