Skip to content

Commit

Permalink
fix missing node normalization to match pytree.convert (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
terencehonles authored May 15, 2021
1 parent 7298175 commit e759da7
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 24 deletions.
80 changes: 56 additions & 24 deletions src/retype/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,73 +489,105 @@ def _sa_expr(expr):
return serialize_attribute(expr.value)


@singledispatch
def convert_annotation(ann):
return normalize_node(_convert_annotation(ann))


def normalize_node(node):
"""Normalizes nodes to match pytree.convert and flatten power nodes"""
if len(node.children) == 1:
return node.children[0]

node.children = [normalize_node(i) for i in node.children]

# if the node is a power node inline child power nodes
if node.type == syms.power:
children = []
for child in node.children:
if child.type == syms.power:
children.extend(child.children)
else:
children.append(child)

node.children = children

return node


@singledispatch
def _convert_annotation(ann):
"""Converts an AST object into its lib2to3 equivalent."""
raise NotImplementedError(f"unknown AST node type: {ann!r}")


@convert_annotation.register(ast3.Subscript)
@_convert_annotation.register(ast3.Subscript)
def _c_subscript(sub):
return Node(
syms.power,
[
convert_annotation(sub.value),
Node(syms.trailer, [new(_lsqb), convert_annotation(sub.slice), new(_rsqb)]),
_convert_annotation(sub.value),
Node(
syms.trailer, [new(_lsqb), _convert_annotation(sub.slice), new(_rsqb)]
),
],
)


@convert_annotation.register(ast3.Name)
@_convert_annotation.register(ast3.Name)
def _c_name(name):
return Leaf(token.NAME, name.id)


@convert_annotation.register(ast3.NameConstant)
@_convert_annotation.register(ast3.NameConstant)
def _c_nameconstant(const):
return Leaf(token.NAME, repr(const.value))


@convert_annotation.register(ast3.Ellipsis)
@_convert_annotation.register(ast3.Ellipsis)
def _c_ellipsis(ell):
return Node(syms.atom, [new(_dot), new(_dot), new(_dot)])


@convert_annotation.register(ast3.Str)
@_convert_annotation.register(ast3.Str)
def _c_str(s):
return Leaf(token.STRING, repr(s.s))


@convert_annotation.register(ast3.Num)
@_convert_annotation.register(ast3.Num)
def _c_num(n):
return Leaf(token.NUMBER, repr(n.n))


@convert_annotation.register(ast3.Index)
@_convert_annotation.register(ast3.Index)
def _c_index(index):
return convert_annotation(index.value)
return _convert_annotation(index.value)


@convert_annotation.register(ast3.Tuple)
@_convert_annotation.register(ast3.Tuple)
def _c_tuple(tup):
contents = [convert_annotation(elt) for elt in tup.elts]
contents = [_convert_annotation(elt) for elt in tup.elts]
for index in range(len(contents) - 1, 0, -1):
contents[index].prefix = " "
contents.insert(index, new(_comma))

return Node(syms.subscriptlist, contents)


@convert_annotation.register(ast3.Attribute)
@_convert_annotation.register(ast3.Attribute)
def _c_attribute(attr):
# This is hacky. ¯\_(ツ)_/¯
return Leaf(token.NAME, f"{convert_annotation(attr.value)}.{attr.attr}")
return Node(
syms.power,
[
_convert_annotation(attr.value),
Node(syms.trailer, [new(_dot), Leaf(token.NAME, attr.attr)]),
],
)


@convert_annotation.register(ast3.Call)
@_convert_annotation.register(ast3.Call)
def _c_call(call):
contents = [convert_annotation(arg) for arg in call.args]
contents.extend(convert_annotation(kwarg) for kwarg in call.keywords)
contents = [_convert_annotation(arg) for arg in call.args]
contents.extend(_convert_annotation(kwarg) for kwarg in call.keywords)
for index in range(len(contents) - 1, 0, -1):
contents[index].prefix = " "
contents.insert(index, new(_comma))
Expand All @@ -564,26 +596,26 @@ def _c_call(call):
if contents:
call_args.insert(1, Node(syms.arglist, contents))
return Node(
syms.power, [convert_annotation(call.func), Node(syms.trailer, call_args)]
syms.power, [_convert_annotation(call.func), Node(syms.trailer, call_args)]
)


@convert_annotation.register(ast3.keyword)
@_convert_annotation.register(ast3.keyword)
def _c_keyword(kwarg):
assert kwarg.arg
return Node(
syms.argument,
[
Leaf(token.NAME, kwarg.arg),
new(_eq, prefix=""),
convert_annotation(kwarg.value),
_convert_annotation(kwarg.value),
],
)


@convert_annotation.register(ast3.List)
@_convert_annotation.register(ast3.List)
def _c_list(l):
contents = [convert_annotation(elt) for elt in l.elts]
contents = [_convert_annotation(elt) for elt in l.elts]
for index in range(len(contents) - 1, 0, -1):
contents[index].prefix = " "
contents.insert(index, new(_comma))
Expand Down
35 changes: 35 additions & 0 deletions tests/test_retype.py
Original file line number Diff line number Diff line change
Expand Up @@ -2259,6 +2259,41 @@ def __init__(self, a1: C, **kwargs) -> None:
self.assertReapplyVisible(pyi_txt, src_txt, expected_txt)


class TypeAliasTestCase(RetypeTestCase):
def test_type_alias_import(self) -> None:
pyi_txt = src_txt = expected_txt = """
import typing
OPTIONAL_STR = typing.Optional[str]
"""
self.assertReapplyVisible(pyi_txt, src_txt, expected_txt)

def test_type_alias_from_import(self) -> None:
pyi_txt = src_txt = expected_txt = """
from typing import Optional
OPTIONAL_STR = Optional[str]
"""
self.assertReapplyVisible(pyi_txt, src_txt, expected_txt)


class NormalizationTestCase(RetypeTestCase):
def test_simple_function(self) -> None:
pyi_txt = src_txt = expected_txt = """
def test(fn: Callable[[str], None]) -> None: ...
"""
self.assertReapplyVisible(pyi_txt, src_txt, expected_txt)

def test_more_complex_function(self) -> None:
pyi_txt = src_txt = expected_txt = """
def build(tmp_path: Path) -> Callable[[Dict[str, str]], Path]:
def _build(files: Dict[str, str]) -> Path:
...
return _build
"""
self.assertReapplyVisible(pyi_txt, src_txt, expected_txt)


class SerializeTestCase(RetypeTestCase):
def test_serialize_attribute(self) -> None:
src_txt = "a.b.c"
Expand Down
1 change: 1 addition & 0 deletions types/src/retype/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def decorator_names(
def names_already_imported(
names: Union[List[ast3.AST], ast3.AST], node: Node
) -> bool: ...
def _convert_annotation(ann: ast3.AST) -> _LN: ...
def convert_annotation(ann: ast3.AST) -> _LN: ...
def serialize_attribute(attr: ast3.AST) -> str: ...
def reapply(
Expand Down

0 comments on commit e759da7

Please sign in to comment.