diff --git a/README.md b/README.md index 4329100..dfda506 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,16 @@ $ astpath -A 1 "//For/body[AugAssign/op/Add and count(child::*)=1]" | head -6 ./pstats.py:513 nc += calls ``` +_Finding classes matching a regular expression:_ +```bash +$ astpath "//ClassDef[re:match('.*Var', @name)]" | head -5 +./typing.py:452 > class TypeVar(_TypingBase, _root=True): +./typing.py:1366 > class _ClassVar(_FinalTypingBase, _root=True): +./tkinter/__init__.py:287 > class Variable: +./tkinter/__init__.py:463 > class StringVar(Variable): +./tkinter/__init__.py:485 > class IntVar(Variable): +``` + `astpath` can also be imported and used programmatically: ```python >>> from astpath import search @@ -112,6 +122,7 @@ pip install astpath * Official `ast` module documentation for [Python 2.7](https://docs.python.org/2.7/library/ast.html) and [Python 3.X](https://docs.python.org/3/library/ast.html). * [Python AST Explorer](https://python-ast-explorer.com/) for worked examples of ASTs. * A [brief guide to XPath](http://www.w3schools.com/xml/xpath_syntax.asp). +* [`bellybutton`](https://github.com/hchasestevens/bellybutton), a fully-featured linting engine built on `astpath`. ## Contacts diff --git a/astpath/asts.py b/astpath/asts.py index af48fa3..411b0a2 100644 --- a/astpath/asts.py +++ b/astpath/asts.py @@ -80,6 +80,13 @@ def convert_to_xml(node, omit_docstrings=False, node_mappings=None): ) elif field_value is not None: + ## add type attribute e.g. so we can distinguish strings from numbers etc + ## in older Python (< 3.8) type could be identified by Str vs Num and s vs n etc + ## e.g. + _set_encoded_literal( + partial(xml_node.set, 'type'), + type(field_value).__name__ + ) _set_encoded_literal( partial(xml_node.set, field_name), field_value diff --git a/astpath/cli.py b/astpath/cli.py index df4dc08..931d2f3 100644 --- a/astpath/cli.py +++ b/astpath/cli.py @@ -16,9 +16,9 @@ parser = argparse.ArgumentParser() -parser.add_argument('-s', '--hide-lines', help="hide source lines, showing only line numbers", action='store_true',) parser.add_argument('-q', '--quiet', help="hide output of matches", action='store_true',) parser.add_argument('-v', '--verbose', help="increase output verbosity", action='store_true',) +parser.add_argument('-x', '--xml', help="print only the matching XML elements", action='store_true',) parser.add_argument('-a', '--abspaths', help="show absolute paths", action='store_true',) parser.add_argument('-R', '--no-recurse', help="ignore subdirectories, searching only files in the specified directory", action='store_true',) parser.add_argument('-d', '--dir', help="search directory or file", default='.',) @@ -39,16 +39,16 @@ def main(): else: recurse = not args.no_recurse - before_context = args.context or args.before_context - after_context = args.context or args.after_context - if (before_context or after_context) and args.hide_lines: + before_context = args.before_context or args.context + after_context = args.after_context or args.context + if (before_context or after_context) and args.quiet: print("ERROR: Context cannot be specified when suppressing output.") exit(1) search( args.dir, ' '.join(args.expr), - show_lines=not args.hide_lines, + print_xml=args.xml, print_matches=not args.quiet, verbose=args.verbose, abspaths=args.abspaths, diff --git a/astpath/search.py b/astpath/search.py index 8fe2b7e..549a0a5 100644 --- a/astpath/search.py +++ b/astpath/search.py @@ -2,18 +2,26 @@ from __future__ import print_function -from itertools import islice +from itertools import islice, repeat import os +import re import ast from astpath.asts import convert_to_xml + +class XMLVersions: + LXML = object() + XML = object() + + try: from lxml.etree import tostring - XML_VERSION = 'lxml' + from lxml import etree + XML_VERSION = XMLVersions.LXML except ImportError: from xml.etree.ElementTree import tostring - XML_VERSION = 'xml' + XML_VERSION = XMLVersions.XML PYTHON_EXTENSION = '{}py'.format(os.path.extsep) @@ -26,7 +34,7 @@ def lxml_query(element, expression): def xml_query(element, expression): return element.findall(expression) - if XML_VERSION == 'lxml': + if XML_VERSION is XMLVersions.LXML: return lxml_query else: if verbose: @@ -38,26 +46,39 @@ def xml_query(element, expression): def _tostring_factory(): - def xml_tostring(*args, pretty_print=False, **kwargs): + def xml_tostring(*args, **kwargs): + kwargs.pop('pretty_print') return tostring(*args, **kwargs) - if XML_VERSION == 'lxml': + if XML_VERSION is XMLVersions.LXML: return tostring else: return xml_tostring -def find_in_ast(xml_ast, expr, return_lines=True, query=_query_factory(), node_mappings=None): - """ - Find items matching expression expr in an XML AST. +if XML_VERSION is XMLVersions.LXML: + regex_ns = etree.FunctionNamespace('https://github.com/hchasestevens/astpath') + regex_ns.prefix = 're' - If return_lines is True, return only matching line numbers, otherwise - returning XML nodes. - """ + @regex_ns + def match(ctx, pattern, strings): + return any( + re.match(pattern, s) is not None + for s in strings + ) + + @regex_ns + def search(ctx, pattern, strings): + return any( + re.search(pattern, s) is not None + for s in strings + ) + + +def find_in_ast(xml_ast, expr, query=_query_factory(), node_mappings=None): + """Find items matching expression expr in an XML AST.""" results = query(xml_ast, expr) - if return_lines: - return linenos_from_xml(results, query=query, node_mappings=node_mappings) - return results + return linenos_from_xml(results, query=query, node_mappings=node_mappings) def linenos_from_xml(elements, query=_query_factory(), node_mappings=None): @@ -72,7 +93,7 @@ def linenos_from_xml(elements, query=_query_factory(), node_mappings=None): # we're not using lxml backend if node_mappings is None: raise ValueError( - "Lines cannot be returned when using native" + "Lines cannot be returned when using native " "backend without `node_mappings` supplied." ) linenos = getattr(node_mappings[element], 'lineno', 0), @@ -104,9 +125,9 @@ def file_to_xml_ast(filename, omit_docstrings=False, node_mappings=None): def search( - directory, expression, print_matches=False, return_lines=True, - show_lines=True, verbose=False, abspaths=False, recurse=True, - before_context=0, after_context=0 + directory, expression, print_matches=False, print_xml=False, + verbose=False, abspaths=False, recurse=True, + before_context=0, after_context=0, extension=PYTHON_EXTENSION ): """ Perform a recursive search through Python files. @@ -114,9 +135,6 @@ def search( Only for files in the given directory for items matching the specified expression. """ - if show_lines and not return_lines: - raise ValueError("`return_lines` must be set if showing lines.") - query = _query_factory(verbose=verbose) if os.path.isfile(directory): @@ -126,20 +144,17 @@ def search( elif recurse: files = os.walk(directory) else: - files = (( - directory, - None, - [ - item for item in os.listdir(directory) - if os.path.isfile(os.path.join(directory, item)) - ] - ),) + files = ((directory, None, [ + item + for item in os.listdir(directory) + if os.path.isfile(os.path.join(directory, item)) + ]),) global_matches = [] for root, __, filenames in files: python_filenames = ( os.path.join(root, filename) for filename in filenames - if filename.endswith(PYTHON_EXTENSION) + if filename.endswith(extension) ) for filename in python_filenames: node_mappings = {} @@ -158,16 +173,18 @@ def search( )) continue # unparseable - file_matches = find_in_ast( - xml_ast, - expression, - return_lines=print_matches or return_lines, - query=query, - node_mappings=node_mappings, - ) + matching_elements = query(xml_ast, expression) + + if print_xml: + tostring = _tostring_factory() + for element in matching_elements: + print(tostring(xml_ast, pretty_print=True)) + + matching_lines = linenos_from_xml(matching_elements, query=query, node_mappings=node_mappings) + global_matches.extend(zip(repeat(filename), matching_lines)) - for match in file_matches: - if print_matches: + if print_matches: + for match in matching_lines: matching_lines = list(context( file_lines, match - 1, before_context, after_context )) @@ -176,15 +193,12 @@ def search( path=os.path.abspath(filename) if abspaths else filename, lineno=lineno, sep='>' if lineno == match - 1 else ' ', - line=line if show_lines else '', + line=line, )) if before_context or after_context: print() - else: - global_matches.append((filename, match)) - if not print_matches: - return global_matches + return global_matches def context(lines, index, before=0, after=0, both=0): diff --git a/setup.py b/setup.py index cb9b7e4..69e0b71 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='astpath', packages=['astpath'], - version='0.6.1', + version='0.9.1', description='A query language for Python abstract syntax trees', license='MIT', author='H. Chase Stevens',