Skip to content

Commit

Permalink
fix string normalization (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
terencehonles authored Oct 4, 2021
1 parent b9b6968 commit b4b19dc
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 25 deletions.
25 changes: 19 additions & 6 deletions src/retype/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,10 +495,12 @@ def convert_annotation(ann):

def normalize_node(node):
"""Normalizes nodes to match pytree.convert and flatten power nodes"""
if len(node.children) == 1:
return node.children[0]

node.children = [normalize_node(i) for i in node.children]
if isinstance(node, Node):
if len(node.children) == 1:
return node.children[0]

node.children = [normalize_node(i) for i in node.children]

# if the node is a power node inline child power nodes
if node.type == syms.power:
Expand Down Expand Up @@ -1143,6 +1145,16 @@ def copy_type_comment_to_annotation(arg):
arg.annotation = ann


def normalize_strings_to_repr(node):
"""Normalize string leaf nodes to a repr since that is what we generate."""
if node.type == token.STRING:
node.value = repr(ast3.literal_eval(node.value))
elif isinstance(node, Node):
node.children = [normalize_strings_to_repr(i) for i in node.children]

return node


def maybe_replace_any_if_equal(name, expected, actual, flags):
"""Return the type given in `expected`.
Expand All @@ -1157,10 +1169,11 @@ def maybe_replace_any_if_equal(name, expected, actual, flags):
2. We want people to be able to explicitly state that they want Any without it
being replaced. This way they can use an alias.
"""
is_equal = expected == actual
normalized = normalize_strings_to_repr(actual.clone())
is_equal = expected == normalized
if not is_equal and flags.replace_any:
actual_str = minimize_whitespace(str(actual))
if actual_str and actual_str[0] in {'"', "'"}:
actual_str = minimize_whitespace(str(normalized))
if actual.type == token.STRING and actual_str:
actual_str = actual_str[1:-1]
is_equal = actual_str in {"Any", "typing.Any", "t.Any"}

Expand Down
51 changes: 33 additions & 18 deletions tests/test_retype.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,30 +53,33 @@ def test_can_run_against_current_directory(tmp_path):
class RetypeTestCase(TestCase):
maxDiff = None

def assertReapply(
self, pyi_txt, src_txt, expected_txt, *, incremental=False, replace_any=False
):
def reapply(self, pyi_txt, src_txt, *, incremental=False, replace_any=False):
pyi = ast3.parse(dedent(pyi_txt))
src = lib2to3_parse(dedent(src_txt))
expected = lib2to3_parse(dedent(expected_txt))
assert isinstance(pyi, ast3.Module)
flags = ReApplyFlags(replace_any=replace_any, incremental=incremental)
reapply_all(pyi.body, src, flags)
fix_remaining_type_comments(src, flags)
return pyi, src

def assertReapply(
self, pyi_txt, src_txt, expected_txt, *, incremental=False, replace_any=False
):
self.longMessage = False
expected = lib2to3_parse(dedent(expected_txt))
pyi, src = self.reapply(
pyi_txt, src_txt, incremental=incremental, replace_any=replace_any
)
self.assertEqual(expected, src, f"\n{expected!r} != \n{src!r}")

def assertReapplyVisible(
self, pyi_txt, src_txt, expected_txt, *, incremental=False, replace_any=False
):
flags = ReApplyFlags(replace_any=replace_any, incremental=incremental)
pyi = ast3.parse(dedent(pyi_txt))
src = lib2to3_parse(dedent(src_txt))
expected = lib2to3_parse(dedent(expected_txt))
assert isinstance(pyi, ast3.Module)
reapply_all(pyi.body, src, flags)
fix_remaining_type_comments(src, flags)
self.longMessage = False
expected = lib2to3_parse(dedent(expected_txt))
pyi, src = self.reapply(
pyi_txt, src_txt, incremental=incremental, replace_any=replace_any
)
self.assertEqual(
str(expected), str(src), f"\n{str(expected)!r} != \n{str(src)!r}"
)
Expand All @@ -90,14 +93,10 @@ def assertReapplyRaises(
incremental=False,
replace_any=False,
):
flags = ReApplyFlags(replace_any=replace_any, incremental=incremental)

with self.assertRaises(expected_exception) as ctx:
pyi = ast3.parse(dedent(pyi_txt))
src = lib2to3_parse(dedent(src_txt))
assert isinstance(pyi, ast3.Module)
reapply_all(pyi.body, src, flags)
fix_remaining_type_comments(src, flags)
self.reapply(
pyi_txt, src_txt, incremental=incremental, replace_any=replace_any
)
return ctx.exception


Expand Down Expand Up @@ -2293,6 +2292,22 @@ def _build(files: Dict[str, str]) -> Path:
"""
self.assertReapplyVisible(pyi_txt, src_txt, expected_txt)

def test_strings_can_be_single_or_double_quotes(self) -> None:
pyi_txt = expected_txt = """
MODE = Literal['r', 'rb', 'w', 'wb']
E = TypeVar('_E', bound=Exception)
"""
src_txt = """
MODE = Literal["r", "rb", "w", "wb"]
E = TypeVar("_E", bound=Exception)
"""
self.assertReapplyVisible(pyi_txt, src_txt, expected_txt)

# allow the opposite quoting, but currently strings are always
# normalized to their repr form
pyi_txt, src_txt = src_txt, pyi_txt
self.assertReapplyVisible(pyi_txt, src_txt, expected_txt)


class SerializeTestCase(RetypeTestCase):
def test_serialize_attribute(self) -> None:
Expand Down
12 changes: 11 additions & 1 deletion types/tests/test_retype.pyi
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
from typing import Type, TypeVar
from lib2to3.pytree import Leaf, Node
from typing import Tuple, Type, TypeVar, Union
from unittest import TestCase

_E = TypeVar("_E", bound=Exception)
_LN = Union[Node, Leaf]

class RetypeTestCase(TestCase):
def reapply(
self,
pyi_txt: str,
src_txt: str,
*,
incremental: bool = ...,
replace_any: bool = ...,
) -> Tuple[_LN, _LN]: ...
def assertReapply(
self,
pyi_txt: str,
Expand Down

0 comments on commit b4b19dc

Please sign in to comment.