diff --git a/src/retype/__init__.py b/src/retype/__init__.py index 951b507..c28b596 100644 --- a/src/retype/__init__.py +++ b/src/retype/__init__.py @@ -489,56 +489,83 @@ def _sa_expr(expr): return serialize_attribute(expr.value) -@singledispatch def convert_annotation(ann): + return normalize_node(_convert_annotation(ann)) + + +def normalize_node(node): + """Normalizes nodes to match pytree.convert and flatten power nodes""" + if len(node.children) == 1: + return node.children[0] + + node.children = [normalize_node(i) for i in node.children] + + # if the node is a power node inline child power nodes + if node.type == syms.power: + children = [] + for child in node.children: + if child.type == syms.power: + children.extend(child.children) + else: + children.append(child) + + node.children = children + + return node + + +@singledispatch +def _convert_annotation(ann): """Converts an AST object into its lib2to3 equivalent.""" raise NotImplementedError(f"unknown AST node type: {ann!r}") -@convert_annotation.register(ast3.Subscript) +@_convert_annotation.register(ast3.Subscript) def _c_subscript(sub): return Node( syms.power, [ - convert_annotation(sub.value), - Node(syms.trailer, [new(_lsqb), convert_annotation(sub.slice), new(_rsqb)]), + _convert_annotation(sub.value), + Node( + syms.trailer, [new(_lsqb), _convert_annotation(sub.slice), new(_rsqb)] + ), ], ) -@convert_annotation.register(ast3.Name) +@_convert_annotation.register(ast3.Name) def _c_name(name): return Leaf(token.NAME, name.id) -@convert_annotation.register(ast3.NameConstant) +@_convert_annotation.register(ast3.NameConstant) def _c_nameconstant(const): return Leaf(token.NAME, repr(const.value)) -@convert_annotation.register(ast3.Ellipsis) +@_convert_annotation.register(ast3.Ellipsis) def _c_ellipsis(ell): return Node(syms.atom, [new(_dot), new(_dot), new(_dot)]) -@convert_annotation.register(ast3.Str) +@_convert_annotation.register(ast3.Str) def _c_str(s): return Leaf(token.STRING, repr(s.s)) -@convert_annotation.register(ast3.Num) +@_convert_annotation.register(ast3.Num) def _c_num(n): return Leaf(token.NUMBER, repr(n.n)) -@convert_annotation.register(ast3.Index) +@_convert_annotation.register(ast3.Index) def _c_index(index): - return convert_annotation(index.value) + return _convert_annotation(index.value) -@convert_annotation.register(ast3.Tuple) +@_convert_annotation.register(ast3.Tuple) def _c_tuple(tup): - contents = [convert_annotation(elt) for elt in tup.elts] + contents = [_convert_annotation(elt) for elt in tup.elts] for index in range(len(contents) - 1, 0, -1): contents[index].prefix = " " contents.insert(index, new(_comma)) @@ -546,16 +573,21 @@ def _c_tuple(tup): return Node(syms.subscriptlist, contents) -@convert_annotation.register(ast3.Attribute) +@_convert_annotation.register(ast3.Attribute) def _c_attribute(attr): - # This is hacky. ¯\_(ツ)_/¯ - return Leaf(token.NAME, f"{convert_annotation(attr.value)}.{attr.attr}") + return Node( + syms.power, + [ + _convert_annotation(attr.value), + Node(syms.trailer, [new(_dot), Leaf(token.NAME, attr.attr)]), + ], + ) -@convert_annotation.register(ast3.Call) +@_convert_annotation.register(ast3.Call) def _c_call(call): - contents = [convert_annotation(arg) for arg in call.args] - contents.extend(convert_annotation(kwarg) for kwarg in call.keywords) + contents = [_convert_annotation(arg) for arg in call.args] + contents.extend(_convert_annotation(kwarg) for kwarg in call.keywords) for index in range(len(contents) - 1, 0, -1): contents[index].prefix = " " contents.insert(index, new(_comma)) @@ -564,11 +596,11 @@ def _c_call(call): if contents: call_args.insert(1, Node(syms.arglist, contents)) return Node( - syms.power, [convert_annotation(call.func), Node(syms.trailer, call_args)] + syms.power, [_convert_annotation(call.func), Node(syms.trailer, call_args)] ) -@convert_annotation.register(ast3.keyword) +@_convert_annotation.register(ast3.keyword) def _c_keyword(kwarg): assert kwarg.arg return Node( @@ -576,14 +608,14 @@ def _c_keyword(kwarg): [ Leaf(token.NAME, kwarg.arg), new(_eq, prefix=""), - convert_annotation(kwarg.value), + _convert_annotation(kwarg.value), ], ) -@convert_annotation.register(ast3.List) +@_convert_annotation.register(ast3.List) def _c_list(l): - contents = [convert_annotation(elt) for elt in l.elts] + contents = [_convert_annotation(elt) for elt in l.elts] for index in range(len(contents) - 1, 0, -1): contents[index].prefix = " " contents.insert(index, new(_comma)) diff --git a/tests/test_retype.py b/tests/test_retype.py index 95c8e8f..9e4380a 100644 --- a/tests/test_retype.py +++ b/tests/test_retype.py @@ -2259,6 +2259,41 @@ def __init__(self, a1: C, **kwargs) -> None: self.assertReapplyVisible(pyi_txt, src_txt, expected_txt) +class TypeAliasTestCase(RetypeTestCase): + def test_type_alias_import(self) -> None: + pyi_txt = src_txt = expected_txt = """ + import typing + + OPTIONAL_STR = typing.Optional[str] + """ + self.assertReapplyVisible(pyi_txt, src_txt, expected_txt) + + def test_type_alias_from_import(self) -> None: + pyi_txt = src_txt = expected_txt = """ + from typing import Optional + + OPTIONAL_STR = Optional[str] + """ + self.assertReapplyVisible(pyi_txt, src_txt, expected_txt) + + +class NormalizationTestCase(RetypeTestCase): + def test_simple_function(self) -> None: + pyi_txt = src_txt = expected_txt = """ + def test(fn: Callable[[str], None]) -> None: ... + """ + self.assertReapplyVisible(pyi_txt, src_txt, expected_txt) + + def test_more_complex_function(self) -> None: + pyi_txt = src_txt = expected_txt = """ + def build(tmp_path: Path) -> Callable[[Dict[str, str]], Path]: + def _build(files: Dict[str, str]) -> Path: + ... + return _build + """ + self.assertReapplyVisible(pyi_txt, src_txt, expected_txt) + + class SerializeTestCase(RetypeTestCase): def test_serialize_attribute(self) -> None: src_txt = "a.b.c" diff --git a/types/src/retype/__init__.pyi b/types/src/retype/__init__.pyi index 8535578..7b9ac6c 100644 --- a/types/src/retype/__init__.pyi +++ b/types/src/retype/__init__.pyi @@ -107,6 +107,7 @@ def decorator_names( def names_already_imported( names: Union[List[ast3.AST], ast3.AST], node: Node ) -> bool: ... +def _convert_annotation(ann: ast3.AST) -> _LN: ... def convert_annotation(ann: ast3.AST) -> _LN: ... def serialize_attribute(attr: ast3.AST) -> str: ... def reapply(